# Source code for lambeq.training.numpy_model

# Copyright 2021-2024 Cambridge Quantum Computing Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
NumpyModel
==========
Module implementing a lambeq model for an exact classical simulation of
a quantum pipeline.

In contrast to the shot-based :py:class:`TketModel`, the state vectors
are calculated classically and stored such that the complex vectors
defining the quantum states are accessible. The results of the
calculations are exact, i.e. noiseless and not shot-based.

"""
from __future__ import annotations

from collections.abc import Callable, Iterable
from typing import Any, TYPE_CHECKING

import numpy
from numpy.typing import ArrayLike

from lambeq.backend import numerical_backend
from lambeq.backend.quantum import Diagram as Circuit
from lambeq.backend.tensor import Diagram
from lambeq.training.quantum_model import QuantumModel


if TYPE_CHECKING:
    from jax import numpy as jnp


class NumpyModel(QuantumModel):
    """A lambeq model for an exact classical simulation of a
    quantum pipeline."""

    def __init__(self, use_jit: bool = False) -> None:
        """Initialise a NumpyModel.

        Parameters
        ----------
        use_jit : bool, default: False
            Whether to use JAX's Just-In-Time compilation.

        """
        super().__init__()
        self.use_jit = use_jit
        # Cache of JIT-compiled evaluation functions keyed by diagram,
        # filled lazily by `_get_lambda`.
        self.lambdas: dict[Diagram, Callable[..., Any]] = {}

    def _evaluate_circuit(self,
                          circuit: Diagram,
                          abs_fn: Callable[[Any], Any]) -> ArrayLike:
        """Contract a substituted circuit and return its normalised output.

        Shared by the JIT and the eager evaluation paths. The caller
        selects the numerical backend by passing the matching `abs_fn`
        and, for JAX, by activating the tensornetwork backend context
        around the call.

        Parameters
        ----------
        circuit : :py:class:`~lambeq.backend.quantum.Diagram`
            A circuit whose symbols have already been substituted with
            concrete weights.
        abs_fn : callable
            Element-wise absolute value of the active numerical backend
            (e.g. ``numpy.abs``).

        Returns
        -------
        ArrayLike
            The normalised contraction result.

        Raises
        ------
        TypeError
            If `circuit` is not a quantum circuit.

        """
        import tensornetwork as tn

        # Explicit check instead of `assert`, so it survives `python -O`
        # and fires before any (potentially expensive) contraction.
        if not isinstance(circuit, Circuit):
            raise TypeError('Expected a quantum circuit, got '
                            f'{type(circuit).__name__}.')

        result = tn.contractors.auto(*circuit.to_tn()).tensor
        # Pure circuits yield amplitudes: square their magnitudes to get
        # probabilities. Mixed-circuit results are used as-is.
        if not circuit.is_mixed:
            result = abs_fn(result) ** 2
        return self._normalise_vector(result)

    def _get_lambda(self, diagram: Diagram) -> Callable[[Any], Any]:
        """Get lambda function that evaluates the provided diagram.

        The JIT-wrapped function is cached in `self.lambdas`, so it is
        created at most once per diagram.

        Parameters
        ----------
        diagram : :py:class:`~lambeq.backend.tensor.Diagram`
            The circuit to build an evaluation function for.

        Returns
        -------
        callable
            A `jax.jit`-wrapped function mapping a weight vector to the
            normalised output of `diagram`.

        Raises
        ------
        ValueError
            If `model.symbols` are not initialised.

        """
        from jax import jit
        import tensornetwork as tn

        if not self.symbols:
            raise ValueError('Symbols not initialised. Instantiate through '
                             '`NumpyModel.from_diagrams()`.')

        if diagram in self.lambdas:
            return self.lambdas[diagram]

        def diagram_output(x: Iterable[ArrayLike]) -> ArrayLike:
            # Run substitution and contraction on JAX so the whole
            # computation is traceable by `jit`.
            with (numerical_backend.backend('jax') as backend,
                  tn.DefaultBackend('jax')):
                sub_circuit = self._fast_subs([diagram], x)[0]
                return self._evaluate_circuit(sub_circuit, backend.abs)

        self.lambdas[diagram] = jit(diagram_output)
        return self.lambdas[diagram]

    def get_diagram_output(
        self,
        diagrams: list[Diagram]
    ) -> jnp.ndarray | numpy.ndarray:
        """Return the exact prediction for each diagram.

        Parameters
        ----------
        diagrams : list of :py:class:`~lambeq.backend.tensor.Diagram`
            The :py:class:`Circuits <lambeq.backend.quantum.Diagram>` to
            be evaluated.

        Raises
        ------
        ValueError
            If `model.weights` or `model.symbols` are not initialised.
        TypeError
            If any of `diagrams` is not a quantum circuit.

        Returns
        -------
        numpy.ndarray or jax.numpy.ndarray
            Resulting array: a JAX array when `use_jit` is enabled,
            otherwise a NumPy array.

        """
        if len(self.weights) == 0 or not self.symbols:
            raise ValueError('Weights and/or symbols not initialised. '
                             'Instantiate through '
                             '`NumpyModel.from_diagrams()` first, '
                             'then call `initialise_weights()`, or load '
                             'from pre-trained checkpoint.')

        if self.use_jit:
            from jax import numpy as jnp

            lambdified_diagrams = [self._get_lambda(d) for d in diagrams]
            # Masked arrays expose `filled()`; convert to a plain array
            # before handing the weights to the JIT-compiled functions.
            if hasattr(self.weights, 'filled'):
                self.weights = self.weights.filled()

            res: jnp.ndarray = jnp.array([diag_f(self.weights)
                                          for diag_f in lambdified_diagrams])
            return res

        sub_circuits = self._fast_subs(diagrams, self.weights)
        return numpy.array([self._evaluate_circuit(circuit, numpy.abs)
                            for circuit in sub_circuits])

    def forward(self, x: list[Diagram]) -> Any:
        """Perform default forward pass of a lambeq model.

        In case of a different datapoint (e.g. list of tuple) or
        additional computational steps, please override this method.

        Parameters
        ----------
        x : list of :py:class:`~lambeq.backend.tensor.Diagram`
            The :py:class:`Circuits <lambeq.backend.quantum.Diagram>` to
            be evaluated.

        Returns
        -------
        numpy.ndarray
            Array containing model's prediction.

        """
        return self.get_diagram_output(x)