# Copyright 2021-2023 Cambridge Quantum Computing Ltd.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module containing the base class for a lambeq optimizer."""

from __future__ import annotations

from abc import ABC, abstractmethod
from import Callable, Iterable, Mapping
from typing import Any

import numpy as np
from numpy.typing import ArrayLike

from import Model

class Optimizer(ABC):
    """Optimizer base class.

    Concrete optimizers must implement :py:meth:`backward`,
    :py:meth:`step`, :py:meth:`state_dict` and
    :py:meth:`load_state_dict`. The base class only stores the
    constructor arguments and owns the gradient accumulator.

    """

    def __init__(self,
                 model: Model,
                 hyperparams: dict[Any, Any],
                 loss_fn: Callable[[Any, Any], float],
                 bounds: ArrayLike | None = None) -> None:
        """Initialise the optimizer base class.

        Parameters
        ----------
        model : Model
            A lambeq model. Its `weights` attribute must support
            `len()`, since it determines the size of the gradient
            accumulator.
        hyperparams : dict
            A dictionary of hyperparameters, stored unchanged on the
            optimizer.
        loss_fn : Callable
            A loss function of form `loss(prediction, labels)`.
        bounds : ArrayLike, optional
            The range of each of the model's parameters.

        """
        self.hyperparams = hyperparams
        self.model = model
        self.loss_fn = loss_fn
        self.bounds = bounds
        # One accumulator slot per model weight, initially all zero.
        self.gradient = np.zeros(len(model.weights))

    @abstractmethod
    def backward(self, batch: tuple[Iterable[Any], np.ndarray]) -> float:
        """Calculate the gradients of the loss function.

        The gradient is calculated with respect to the model
        parameters.

        Parameters
        ----------
        batch : tuple of list and numpy.ndarray
            Current batch.

        Returns
        -------
        float
            The calculated loss.

        """

    @abstractmethod
    def step(self) -> None:
        """Perform optimisation step."""

    @abstractmethod
    def state_dict(self) -> dict[str, Any]:
        """Return optimizer states as dictionary."""

    @abstractmethod
    def load_state_dict(self, state: Mapping[str, Any]) -> None:
        """Load state of the optimizer from the state dictionary."""

    def zero_grad(self) -> None:
        """Reset the gradients to zero.

        The reset happens in place, so the array object held in
        `self.gradient` stays the same.

        """
        # Overwrite rather than multiply by zero: `nan * 0` and
        # `inf * 0` are both `nan`, so a multiplicative reset would
        # leave a diverged gradient poisoned forever.
        self.gradient.fill(0)