# Copyright 2021-2024 Cambridge Quantum Computing Ltd.## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License."""NumpyModel==========Module implementing a lambeq model for an exact classical simulation ofa quantum pipeline.In contrast to the shot-based :py:class:`TketModel`, the state vectorsare calculated classically and stored such that the complex vectorsdefining the quantum states are accessible. The results of thecalculations are exact i.e. noiseless and not shot-based."""from__future__importannotationsfromcollections.abcimportCallable,IterablefromtypingimportAny,TYPE_CHECKINGimportnumpyfromnumpy.typingimportArrayLikefromlambeq.backendimportnumerical_backendfromlambeq.backend.quantumimportDiagramasCircuitfromlambeq.backend.tensorimportDiagramfromlambeq.training.quantum_modelimportQuantumModelifTYPE_CHECKING:fromjaximportnumpyasjnp
class NumpyModel(QuantumModel):
    """Exact, noiseless classical simulation model for quantum pipelines.

    Circuits are contracted as tensor networks, either directly with
    NumPy or, when ``use_jit`` is set, through JAX-jitted callables
    that are cached per diagram in ``lambdas``.

    """

    def __init__(self, use_jit: bool = False) -> None:
        """Initialise a NumpyModel.

        Parameters
        ----------
        use_jit : bool, default: False
            Whether to use JAX's Just-In-Time compilation.

        """
        super().__init__()
        # Per-diagram cache of jitted evaluation functions; it is
        # filled lazily by `_get_lambda`.
        self.lambdas: dict[Diagram, Callable[..., Any]] = {}
        self.use_jit = use_jit

    def _get_lambda(self, diagram: Diagram) -> Callable[[Any], Any]:
        """Return the jitted function that evaluates `diagram`.

        The function is built the first time a diagram is requested
        and stored in `self.lambdas`; later requests for the same
        diagram return the stored function.

        Parameters
        ----------
        diagram : :py:class:`~lambeq.backend.tensor.Diagram`
            The diagram to create an evaluation function for.

        Returns
        -------
        Callable
            Function mapping the model weights to the normalised
            output of the diagram.

        Raises
        ------
        ValueError
            If `model.symbols` are not initialised.

        """
        from jax import jit
        import tensornetwork as tn

        if not self.symbols:
            raise ValueError('Symbols not initialised. Instantiate through '
                             '`NumpyModel.from_diagrams()`.')

        cached = self.lambdas.get(diagram)
        if cached is not None:
            return cached

        def diagram_output(x: Iterable[ArrayLike]) -> ArrayLike:
            # Substitution and contraction both run with JAX selected
            # as the numerical and the tensornetwork backend.
            with numerical_backend.backend('jax') as backend:
                with tn.DefaultBackend('jax'):
                    circuit = self._fast_subs([diagram], x)[0]
                    network = circuit.to_tn()
                    tensor = tn.contractors.auto(*network).tensor
                    assert isinstance(circuit, Circuit)
                    # Pure circuits yield amplitudes: square their
                    # magnitudes to obtain probabilities.
                    if not circuit.is_mixed:
                        tensor = backend.abs(tensor) ** 2
                    return self._normalise_vector(tensor)

        compiled = jit(diagram_output)
        self.lambdas[diagram] = compiled
        return compiled

    def get_diagram_output(
        self,
        diagrams: list[Diagram]
    ) -> jnp.ndarray | numpy.ndarray:
        """Return the exact prediction for each diagram.

        Parameters
        ----------
        diagrams : list of :py:class:`~lambeq.backend.tensor.Diagram`
            The :py:class:`Circuits <lambeq.backend.quantum.Diagram>`
            to be evaluated.

        Raises
        ------
        ValueError
            If `model.weights` or `model.symbols` are not initialised.

        Returns
        -------
        jnp.ndarray or numpy.ndarray
            Resulting array: a JAX array when `use_jit` is set,
            otherwise a NumPy array.

        """
        import tensornetwork as tn

        has_weights = len(self.weights) > 0
        if not has_weights or not self.symbols:
            raise ValueError('Weights and/or symbols not initialised. '
                             'Instantiate through '
                             '`NumpyModel.from_diagrams()` first, '
                             'then call `initialise_weights()`, or load '
                             'from pre-trained checkpoint.')

        if self.use_jit:
            from jax import numpy as jnp

            evaluators = [self._get_lambda(diagram) for diagram in diagrams]
            # Weights exposing `filled` (e.g. masked arrays) are
            # replaced by their filled version before evaluation.
            if hasattr(self.weights, 'filled'):
                self.weights = self.weights.filled()
            outputs: jnp.ndarray = jnp.array(
                [evaluate(self.weights) for evaluate in evaluators]
            )
            return outputs

        def evaluate_circuit(circuit: Any) -> Any:
            assert isinstance(circuit, Circuit)
            tensor = tn.contractors.auto(*circuit.to_tn()).tensor
            # Pure circuits yield amplitudes: square their magnitudes
            # to obtain probabilities.
            if not circuit.is_mixed:
                tensor = numpy.abs(tensor) ** 2
            return self._normalise_vector(tensor)

        circuits = self._fast_subs(diagrams, self.weights)
        return numpy.array([evaluate_circuit(circuit)
                            for circuit in circuits])

    def forward(self, x: list[Diagram]) -> Any:
        """Perform the default forward pass of a lambeq model.

        Override this method when the datapoints have a different
        form (e.g. a list of tuples) or when additional computational
        steps are required.

        Parameters
        ----------
        x : list of :py:class:`~lambeq.backend.tensor.Diagram`
            The :py:class:`Circuits <lambeq.backend.quantum.Diagram>`
            to be evaluated.

        Returns
        -------
        numpy.ndarray
            Array containing the model's prediction, as returned by
            `get_diagram_output`.

        """
        return self.get_diagram_output(x)