# Copyright 2021-2024 Cambridge Quantum Computing Ltd.## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License."""Model=====Module containing the base class for a lambeq model."""from__future__importannotationsfromabcimportABC,abstractmethodfromcollections.abcimportCollectionfromtypingimportAnyfromsympyimportdefault_sort_key,SymbolasSymPySymbolfromlambeq.ansatz.baseimportSymbolfromlambeq.backend.tensorimportDiagramfromlambeq.training.checkpointimportCheckpointfromlambeq.typingimportStrPathT
class Model(ABC):
    """Model base class.

    Attributes
    ----------
    symbols : list of symbols
        A sorted list of all :py:class:`Symbols <.Symbol>` occurring in
        the data.
    weights : Collection
        A data structure containing the numeric values of the model's
        parameters.

    """

    def __init__(self) -> None:
        """Initialise an instance of :py:class:`Model` base class.

        Both :py:attr:`symbols` and :py:attr:`weights` start out empty;
        they are populated later, e.g. by :py:meth:`from_diagrams`,
        :py:meth:`load` or :py:meth:`initialise_weights`.

        """
        self.symbols: list[Symbol | SymPySymbol] = []
        self.weights: Collection = []

    @abstractmethod
    def initialise_weights(self) -> None:
        """Initialise the weights of the model."""

    @classmethod
    def from_checkpoint(cls,
                        checkpoint_path: StrPathT,
                        **kwargs: Any) -> Model:
        """Load the weights and symbols from a training checkpoint.

        A new model is constructed as ``cls(**kwargs)`` and its state is
        then restored from the checkpoint file.

        Parameters
        ----------
        checkpoint_path : str or PathLike
            Path that points to the checkpoint file.

        Other Parameters
        ----------------
        backend_config : dict
            Dictionary containing the backend configuration for the
            :py:class:`TketModel`. Must include the fields `'backend'`,
            `'compilation'` and `'shots'`.

        Notes
        -----
        All keyword arguments are forwarded unchanged to the constructor
        of the concrete model class.

        Returns
        -------
        Model
            A new instance of ``cls`` holding the checkpointed state.

        """
        model = cls(**kwargs)
        checkpoint = Checkpoint.from_file(checkpoint_path)
        model._load_checkpoint(checkpoint)
        return model

    @abstractmethod
    def _load_checkpoint(self, checkpoint: Checkpoint) -> None:
        """Load the model weights and symbols from a lambeq
        :py:class:`.Checkpoint`.

        Parameters
        ----------
        checkpoint : Checkpoint
            :py:class:`.Checkpoint` containing the model weights,
            symbols and additional information.

        """

    @abstractmethod
    def _make_checkpoint(self) -> Checkpoint:
        """Create checkpoint that contains the model weights and symbols.

        Returns
        -------
        Checkpoint
            :py:class:`.Checkpoint` containing the model weights,
            symbols and additional information.

        """

    def save(self, checkpoint_path: StrPathT) -> None:
        """Create a lambeq :py:class:`.Checkpoint` and save to a path.

        Example:
        >>> from lambeq import PytorchModel
        >>> model = PytorchModel()
        >>> model.save('my_checkpoint.lt')

        Parameters
        ----------
        checkpoint_path : str or PathLike
            Path that points to the checkpoint file.

        """
        checkpoint = self._make_checkpoint()
        checkpoint.to_file(checkpoint_path)

    def load(self, checkpoint_path: StrPathT) -> None:
        """Load model data from a path pointing to a lambeq checkpoint.

        Checkpoints that are created by a lambeq :py:class:`Trainer`
        usually have the extension `.lt`.

        Parameters
        ----------
        checkpoint_path : str or PathLike
            Path that points to the checkpoint file.

        """
        checkpoint = Checkpoint.from_file(checkpoint_path)
        self._load_checkpoint(checkpoint)

    @abstractmethod
    def get_diagram_output(self, diagrams: list[Diagram]) -> Any:
        """Return the diagram prediction.

        Parameters
        ----------
        diagrams : list of :py:class:`~lambeq.backend.tensor.Diagram`
            The tensor or circuit diagrams to be evaluated.

        Returns
        -------
        Any
            The prediction for the diagrams; its concrete type is
            defined by each subclass.

        """

    @abstractmethod
    def forward(self, x: list[Any]) -> Any:
        """The forward pass of the model.

        Parameters
        ----------
        x : list
            The model input; the element type is defined by each
            subclass.

        Returns
        -------
        Any
            The model output; its concrete type is defined by each
            subclass.

        """

    @classmethod
    def from_diagrams(cls, diagrams: list[Diagram], **kwargs: Any) -> Model:
        """Build model from a list of
        :py:class:`Diagrams <lambeq.backend.tensor.Diagram>`.

        Parameters
        ----------
        diagrams : list of :py:class:`~lambeq.backend.tensor.Diagram`
            The tensor or circuit diagrams to be evaluated.

        Other Parameters
        ----------------
        backend_config : dict
            Dictionary containing the backend configuration for the
            :py:class:`TketModel`. Must include the fields `'backend'`,
            `'compilation'` and `'shots'`.
        use_jit : bool, default: False
            Whether to use JAX's Just-In-Time compilation in
            :py:class:`NumpyModel`.

        Returns
        -------
        Model
            A new instance of ``cls`` whose :py:attr:`symbols` holds the
            de-duplicated, sorted free symbols of all ``diagrams``.

        """
        model = cls(**kwargs)
        # A set removes symbols shared between diagrams; its iteration
        # order is arbitrary, so sort with sympy's ``default_sort_key`` to
        # get a deterministic symbol order.
        model.symbols = sorted(
            {sym for circ in diagrams for sym in circ.free_symbols},
            key=default_sort_key,
        )
        return model