# Copyright 2021-2024 Cambridge Quantum Computing Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Loss Functions
==============
Module containing loss functions to train lambeq's quantum models.

"""
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

import numpy as np

if TYPE_CHECKING:
    from jax import numpy as jnp
    from types import ModuleType


class LossFunction(ABC):
    """Loss function base class.

    Attributes
    ----------
    backend : ModuleType
        The module to use for array numerical functions. Either
        numpy or jax.numpy.

    """

    def __init__(self, use_jax: bool = False) -> None:
        """Initialise a loss function.

        Parameters
        ----------
        use_jax : bool, default: False
            Whether to use the Jax variant of numpy as `backend`.

        """
        self.backend: ModuleType
        if use_jax:
            from jax import numpy as jnp
            self.backend = jnp
        else:
            self.backend = np

    def _match_shapes(self,
                      y1: np.ndarray | jnp.ndarray,
                      y2: np.ndarray | jnp.ndarray) -> None:
        if y1.shape != y2.shape:
            raise ValueError('Provided arrays must be of equal shape. Got '
                             f'arrays of shape {y1.shape} and {y2.shape}.')

    def _smooth_and_normalise(self,
                              y: np.ndarray | jnp.ndarray,
                              epsilon: float) -> np.ndarray | jnp.ndarray:
        # Add a small constant to avoid log(0), then re-normalise each
        # row so that it remains a probability distribution.
        y_smoothed = y + epsilon
        l1_norms: np.ndarray | jnp.ndarray = self.backend.linalg.norm(
            y_smoothed, ord=1, axis=1, keepdims=True)
        return y_smoothed / l1_norms

    @abstractmethod
    def calculate_loss(self,
                       y_pred: np.ndarray | jnp.ndarray,
                       y_true: np.ndarray | jnp.ndarray) -> float:
        """Calculate value of loss function."""

    def __call__(self,
                 y_pred: np.ndarray | jnp.ndarray,
                 y_true: np.ndarray | jnp.ndarray) -> float:
        return self.calculate_loss(y_pred, y_true)
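
# A custom loss can be defined by subclassing `LossFunction` and
# implementing `calculate_loss`. The sketch below is illustrative and
# not part of lambeq; `MAELoss` is a hypothetical name. Using
# `self.backend` keeps the subclass agnostic to whether numpy or
# jax.numpy is in use.
class MAELoss(LossFunction):
    """Mean absolute error loss function (illustrative example)."""

    def calculate_loss(self,
                       y_pred: np.ndarray | jnp.ndarray,
                       y_true: np.ndarray | jnp.ndarray) -> float:
        """Calculate value of MAE loss function."""
        self._match_shapes(y_pred, y_true)
        return float(self.backend.mean(self.backend.abs(y_pred - y_true)))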
class CrossEntropyLoss(LossFunction):
    """Multiclass cross-entropy loss function.

    Parameters
    ----------
    y_pred : np.ndarray or jnp.ndarray
        Predicted labels from model. Expected to be of shape
        [batch_size, n_classes], where each row is a probability
        distribution.
    y_true : np.ndarray or jnp.ndarray
        Ground truth labels. Expected to be of shape
        [batch_size, n_classes], where each row is a one-hot vector.

    """

    def __init__(self, use_jax: bool = False, epsilon: float = 1e-9) -> None:
        """Initialise a multiclass cross-entropy loss function.

        Parameters
        ----------
        use_jax : bool, default: False
            Whether to use the Jax variant of numpy.
        epsilon : float, default: 1e-9
            Smoothing constant used to prevent calculating log(0).

        """
        self._epsilon = epsilon
        super().__init__(use_jax)

    def calculate_loss(self,
                       y_pred: np.ndarray | jnp.ndarray,
                       y_true: np.ndarray | jnp.ndarray) -> float:
        """Calculate value of CE loss function."""
        self._match_shapes(y_pred, y_true)
        y_pred_smoothed = self._smooth_and_normalise(y_pred, self._epsilon)
        entropies = y_true * self.backend.log(y_pred_smoothed)
        loss_val: float = -self.backend.sum(entropies) / len(y_true)
        return loss_val
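
# Usage sketch for CrossEntropyLoss (illustrative numbers, not from the
# lambeq documentation); rows of `y_pred` are probability distributions
# and rows of `y_true` are one-hot vectors:
#
#     >>> loss = CrossEntropyLoss()
#     >>> y_true = np.array([[1.0, 0.0], [0.0, 1.0]])
#     >>> y_pred = np.array([[0.9, 0.1], [0.4, 0.6]])
#     >>> loss(y_pred, y_true)   # -(log(0.9) + log(0.6)) / 2 ~= 0.308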
class BinaryCrossEntropyLoss(CrossEntropyLoss):
    """Binary cross-entropy loss function.

    Parameters
    ----------
    y_pred : np.ndarray or jnp.ndarray
        Predicted labels from model. When `sparse` is `False`, expected
        to be of shape [batch_size, 2], where each row is a probability
        distribution. When `sparse` is `True`, expected to be of shape
        [batch_size, ], where each element indicates P(1).
    y_true : np.ndarray or jnp.ndarray
        Ground truth labels. When `sparse` is `False`, expected to be
        of shape [batch_size, 2], where each row is a one-hot vector.
        When `sparse` is `True`, expected to be of shape [batch_size, ],
        where each element is an integer indicating the class label.

    """

    def __init__(self,
                 sparse: bool = False,
                 use_jax: bool = False,
                 epsilon: float = 1e-9) -> None:
        """Initialise a binary cross-entropy loss function.

        Parameters
        ----------
        sparse : bool, default: False
            If True, each input element indicates P(1); otherwise, a
            probability distribution over the classes is expected.
        use_jax : bool, default: False
            Whether to use the Jax variant of numpy.
        epsilon : float, default: 1e-9
            Smoothing constant used to prevent calculating log(0).

        """
        self._sparse = sparse
        super().__init__(use_jax, epsilon)

    def calculate_loss(self,
                       y_pred: np.ndarray | jnp.ndarray,
                       y_true: np.ndarray | jnp.ndarray) -> float:
        """Calculate value of BCE loss function."""
        if self._sparse:
            # For numerical stability, it is convenient to reshape the
            # sparse input to a dense representation.
            self._match_shapes(y_pred, y_true)
            y_pred_dense = self.backend.stack((1 - y_pred, y_pred)).T
            y_true_dense = self.backend.stack((1 - y_true, y_true)).T
            return super().calculate_loss(y_pred_dense, y_true_dense)
        else:
            return super().calculate_loss(y_pred, y_true)
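
# Usage sketch for BinaryCrossEntropyLoss (illustrative numbers, not
# from the lambeq documentation). With `sparse=True`, each prediction
# is P(1) and each label is 0 or 1; both are expanded internally to the
# dense two-column form before the cross-entropy is computed:
#
#     >>> bce = BinaryCrossEntropyLoss(sparse=True)
#     >>> y_true = np.array([1, 0, 1])
#     >>> y_pred = np.array([0.8, 0.3, 0.6])
#     >>> bce(y_pred, y_true)   # -(log(0.8) + log(0.7) + log(0.6)) / 3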
class MSELoss(LossFunction):
    """Mean squared error loss function.

    Parameters
    ----------
    y_pred : np.ndarray or jnp.ndarray
        Predicted values from model. Shape must match `y_true`.
    y_true : np.ndarray or jnp.ndarray
        Ground truth values.

    """

    def calculate_loss(self,
                       y_pred: np.ndarray | jnp.ndarray,
                       y_true: np.ndarray | jnp.ndarray) -> float:
        """Calculate value of MSE loss function."""
        self._match_shapes(y_pred, y_true)
        return float(self.backend.mean((y_pred - y_true) ** 2))
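
# Usage sketch for MSELoss (illustrative numbers, not from the lambeq
# documentation); unlike the cross-entropy losses, no smoothing is
# applied and the inputs need not be probability distributions:
#
#     >>> mse = MSELoss()
#     >>> y_pred = np.array([0.1, 0.5, 0.9])
#     >>> y_true = np.array([0.0, 0.5, 1.0])
#     >>> mse(y_pred, y_true)   # (0.01 + 0.0 + 0.01) / 3 ~= 0.0067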