# Copyright 2021-2024 Cambridge Quantum Computing Ltd.## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License."""QuantumModel============Module containing the base class for a quantum lambeq model."""from__future__importannotationsfromabcimportabstractmethodfromcollections.abcimportIterableimportpicklefromtypingimportAny,TYPE_CHECKINGimportnumpyasnpfromnumpy.typingimportArrayLikefromsympyimportlambdifyfromlambeq.backendimportnumerical_backendfromlambeq.backend.tensorimportDiagramfromlambeq.training.checkpointimportCheckpointfromlambeq.training.modelimportModelifTYPE_CHECKING:fromjaximportnumpyasjnp
[docs]classQuantumModel(Model):"""Quantum Model base class. Attributes ---------- symbols : list of symbols A sorted list of all :py:class:`Symbols <.Symbol>` occurring in the data. weights : array A data structure containing the numeric values of the model parameters """weights:np.ndarray
def __init__(self) -> None:
    """Create a :py:class:`QuantumModel`.

    After the base-class initialiser runs, the model is put in
    non-training mode and given an empty list for logged predictions.
    """
    super().__init__()
    # The two attributes below are independent of each other.
    self._train_predictions: list[Any] = []
    self._training = False
def_log_prediction(self,y:Any)->None:"""Log a prediction of the model."""self._train_predictions.append(y)def_clear_predictions(self)->None:"""Clear the logged predictions of the model."""self._train_predictions=[]def_normalise_vector(self,predictions:np.ndarray)->np.ndarray:"""Normalise the vector input. Special cases: * scalar value: Returns the absolute value. * zero-vector: Returns the vector as-is. """backend=numerical_backend.get_backend()ret:np.ndarray=backend.abs(predictions)ifpredictions.shape:# Prevent division by 0l1_norm=backend.maximum(1e-9,ret.sum())ret=ret/l1_normreturnret
def initialise_weights(self) -> None:
    """Fill ``self.weights`` with random starting values.

    One value per entry of ``self.symbols`` is drawn with
    ``np.random.rand``, i.e. uniformly from ``[0, 1)``.

    Raises
    ------
    ValueError
        If ``self.symbols`` is empty (or otherwise falsy). The error
        message advises instantiating through ``from_diagrams()``.

    """
    symbols = self.symbols
    if not symbols:
        raise ValueError('Symbols not initialised. Instantiate through '
                         '`from_diagrams()`.')
    n_params = len(symbols)
    self.weights = np.random.rand(n_params)
def_load_checkpoint(self,checkpoint:Checkpoint)->None:"""Load the model weights and symbols from a lambeq :py:class:`.Checkpoint`. Parameters ---------- checkpoint : :py:class:`.Checkpoint` Checkpoint containing the model weights, symbols and additional information. """self.symbols=checkpoint['model_symbols']self.weights=checkpoint['model_weights']def_make_checkpoint(self)->Checkpoint:"""Create checkpoint that contains the model weights and symbols. Returns ------- :py:class:`.Checkpoint` Checkpoint containing the model weights, symbols and additional information. """checkpoint=Checkpoint()checkpoint.add_many({'model_symbols':self.symbols,'model_weights':self.weights})returncheckpointdef_fast_subs(self,diagrams:list[Diagram],weights:Iterable[ArrayLike])->list[Diagram]:"""Substitute weights into a list of parameterised circuit."""parameters={k:vfork,vinzip(self.symbols,weights)}diagrams=pickle.loads(pickle.dumps(diagrams))# does fast deepcopyfordiagramindiagrams:forbindiagram.boxes:ifb.free_symbols:whilehasattr(b,'controlled'):b=b.controlledsyms,values=[],[]forsyminb.free_symbols:syms.append(sym)try:values.append(parameters[sym])exceptKeyErrorase:raiseKeyError(f'Unknown symbol: {repr(sym)}')fromeb.data=lambdify(syms,b.data)(*values)# type: ignore[attr-defined] # noqa: E501# The name of this box isnt updated correctlydelb.free_symbolsreturndiagrams
@abstractmethod
def get_diagram_output(
    self,
    diagrams: list[Diagram],
) -> jnp.ndarray | np.ndarray:
    """Return the model's prediction for the given diagrams.

    Subclasses must provide the implementation.

    Parameters
    ----------
    diagrams : list of :py:class:`~lambeq.backend.quantum.Diagram`
        The :py:class:`Circuits <lambeq.backend.quantum.Diagram>` to be
        evaluated.

    Returns
    -------
    jax.numpy.ndarray or numpy.ndarray
        The predictions for ``diagrams``.

    """