Source code for

# Copyright 2021-2023 Cambridge Quantum Computing Ltd.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Module implementing the Simultaneous Perturbation Stochastic
Approximation (SPSA) optimizer.
"""

from __future__ import annotations

from import Callable, Iterable, Mapping
from typing import Any

import numpy as np
from numpy.typing import ArrayLike

from lambeq.core.utils import flatten
from import Optimizer
from import QuantumModel

class SPSAOptimizer(Optimizer):
    """An Optimizer using SPSA.

    SPSA = Simultaneous Perturbation Stochastic Approximation.

    The gradient is estimated from two loss evaluations per batch: all
    parameters are shifted simultaneously by ``+ck * delta`` and
    ``-ck * delta``, where ``delta`` is a random vector of ``+1``/``-1``
    entries. The gains ``ak`` (step size) and ``ck`` (perturbation
    size) decay with the sweep counter::

        ak = a / (current_sweep + A) ** alpha
        ck = c / current_sweep ** gamma

    See J. C. Spall's publications on SPSA for details.

    """

    model: QuantumModel

    def __init__(self,
                 model: QuantumModel,
                 hyperparams: dict[str, float],
                 loss_fn: Callable[[Any, Any], float],
                 bounds: ArrayLike | None = None) -> None:
        """Initialise the SPSA optimizer.

        The hyperparameters must contain the following key value pairs::

            hyperparams = {
                'a': A learning rate parameter, float
                'c': The parameter shift scaling factor, float
                'A': A stability constant, float
            }

        A good value for 'A' is approximately: 0.01 * Num Training steps

        Parameters
        ----------
        model : :py:class:`.QuantumModel`
            A lambeq quantum model.
        hyperparams : dict of str to float.
            A dictionary containing the models hyperparameters.
        loss_fn : Callable
            A loss function of form `loss(prediction, labels)`.
        bounds : ArrayLike, optional
            The range of each of the model parameters. Indexed as
            ``bounds[:, 0]`` (lower) and ``bounds[:, 1]`` (upper), i.e.
            one ``(low, high)`` pair per model parameter.

        Raises
        ------
        KeyError
            If any of the keys 'a', 'c' or 'A' is missing from
            `hyperparams`.
        ValueError
            If the length of `bounds` does not match the number
            of the model parameters.

        """
        fields = ('a', 'c', 'A')
        if any(field not in hyperparams for field in fields):
            raise KeyError('Missing arguments in hyperparameter dict '
                           f'configuration. Must contain {fields}.')
        super().__init__(model, hyperparams, loss_fn, bounds)

        # Decay exponents of the step-size and perturbation-size gains.
        self.alpha = 0.602
        self.gamma = 0.101
        self.current_sweep = 1
        self.A = self.hyperparams['A']
        self.a = self.hyperparams['a']
        self.c = self.hyperparams['c']
        self.ak = self.a / (self.current_sweep + self.A) ** self.alpha
        self.ck = self.c / self.current_sweep ** self.gamma

        # `project` maps a weight vector back into the feasible region;
        # it is the identity when no bounds are given.
        self.project: Callable[[np.ndarray], np.ndarray]
        if self.bounds is None:
            self.project = lambda _: _
        else:
            bds = np.asarray(bounds)
            if len(bds) != len(self.model.weights):
                raise ValueError('Length of `bounds` must be the same as the '
                                 'number of the model parameters')
            self.project = lambda x: x.clip(bds[:, 0], bds[:, 1])

    def backward(
            self,
            batch: tuple[Iterable[Any], np.ndarray]) -> float:
        """Calculate the gradients of the loss function.

        The gradients are calculated with respect to the model
        parameters and accumulated into ``self.gradient``. The model
        weights are restored to their original values before returning.

        Parameters
        ----------
        batch : tuple of Iterable and numpy.ndarray
            Current batch. Contains an Iterable of diagrams in index 0,
            and the targets in index 1.

        Returns
        -------
        float
            The calculated loss, i.e. the mean of the losses at the two
            perturbed parameter vectors.

        """
        diagrams, targets = batch
        diags_gen = flatten(diagrams)

        # the symbolic parameters
        parameters = self.model.symbols
        x = self.model.weights

        # Only symbols that occur in this batch receive a gradient.
        # `set().union(...)` also copes with a batch without diagrams,
        # where `set.union(*[])` would raise a TypeError.
        relevant_params = set().union(
            *(diag.free_symbols for diag in diags_gen)
        )
        mask = np.array([1 if sym in relevant_params else 0
                         for sym in parameters])

        # the perturbations
        delta = np.random.choice([-1, 1], size=len(x))

        # calculate gradient
        xplus = self.project(x + self.ck * delta)
        self.model.weights = xplus
        y0 = self.model(diagrams)
        loss0 = self.loss_fn(y0, targets)

        xminus = self.project(x - self.ck * delta)
        self.model.weights = xminus
        y1 = self.model(diagrams)
        loss1 = self.loss_fn(y1, targets)

        if self.bounds is None:
            grad = (loss0 - loss1) / (2 * self.ck * delta)
        else:
            # Projection may have shortened the step, so divide by the
            # distance that was actually realised per parameter.
            grad = (loss0 - loss1) / (xplus - xminus)
        self.gradient += grad * mask

        # restore parameter value
        self.model.weights = x

        loss = (loss0 + loss1) / 2
        return loss

    def step(self) -> None:
        """Perform optimisation step.

        Moves the weights against the accumulated gradient scaled by
        ``ak``, projects them back into the bounds, advances the gain
        schedule and resets the accumulated gradient.

        """
        self.model.weights -= self.gradient * self.ak
        self.model.weights = self.project(self.model.weights)

        self.update_hyper_params()
        self.zero_grad()

    def update_hyper_params(self) -> None:
        """Update the hyperparameters of the SPSA algorithm.

        Increments the sweep counter and recomputes the gains ``ak`` and
        ``ck`` from it.

        """
        self.current_sweep += 1
        a_decay = (self.current_sweep + self.A) ** self.alpha
        c_decay = self.current_sweep ** self.gamma
        # Use the instance attributes (not `self.hyperparams`) so that
        # values restored by `load_state_dict` are honoured.
        self.ak = self.a / a_decay
        self.ck = self.c / c_decay

    def state_dict(self) -> dict[str, Any]:
        """Return optimizer states as dictionary.

        Returns
        -------
        dict
            A dictionary containing the current state of the optimizer,
            with keys 'A', 'a', 'c', 'ak', 'ck' and 'current_sweep'.

        """
        return {'A': self.A,
                'a': self.a,
                'c': self.c,
                'ak': self.ak,
                'ck': self.ck,
                'current_sweep': self.current_sweep}

    def load_state_dict(self, state_dict: Mapping[str, Any]) -> None:
        """Load state of the optimizer from the state dictionary.

        Parameters
        ----------
        state_dict : dict
            A dictionary containing a snapshot of the optimizer state,
            as produced by `state_dict`.

        """
        self.A = state_dict['A']
        self.a = state_dict['a']
        self.c = state_dict['c']
        self.ak = state_dict['ak']
        self.ck = state_dict['ck']
        self.current_sweep = state_dict['current_sweep']