# Copyright 2021-2024 Cambridge Quantum Computing Ltd.## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License."""PytorchModel============Module implementing a basic lambeq model based on a Pytorch backend."""from__future__importannotationsfrommathimportsqrtfromtypingimportSequencefromtensornetworkimportAbstractNodefromtensornetworkimportEdgeimporttorchfromlambeq.backend.numerical_backendimportbackendfromlambeq.backend.symbolimportSymbolfromlambeq.backend.tensorimportDiagramfromlambeq.core.utilsimportfast_deepcopyfromlambeq.training.checkpointimportCheckpointfromlambeq.training.modelimportModelfromlambeq.training.tn_path_optimizerimport(CachedTnPathOptimizer,ordered_nodes_contractor,TnPathOptimizer)
class PytorchModel(Model, torch.nn.Module):
    """A lambeq model for the classical pipeline using PyTorch."""

    weights: torch.nn.ParameterList  # type: ignore[assignment]
    symbols: list[Symbol]
    tn_path_optimizer: TnPathOptimizer

    def __init__(self,
                 tn_path_optimizer: TnPathOptimizer | None = None) -> None:
        """Initialise a PytorchModel.

        Parameters
        ----------
        tn_path_optimizer : TnPathOptimizer, optional
            Path optimizer passed to the tensor-network contractor. If
            omitted, a new :py:class:`CachedTnPathOptimizer` is used.

        """
        Model.__init__(self)
        torch.nn.Module.__init__(self)
        # Compare against None explicitly: a user-supplied optimizer must
        # never be discarded just because it happens to evaluate as falsy.
        if tn_path_optimizer is None:
            tn_path_optimizer = CachedTnPathOptimizer()
        self.tn_path_optimizer = tn_path_optimizer

    def _tn_contract(self,
                     nodes: list[AbstractNode],
                     output_edge_order: Sequence[Edge] | None = None,
                     ignore_edge_order: bool = False):
        """Contract a tensor network using the model's path optimizer.

        Thin wrapper around :py:func:`ordered_nodes_contractor` that
        supplies ``self.tn_path_optimizer``.

        Parameters
        ----------
        nodes : list of AbstractNode
            The nodes of the tensor network to contract.
        output_edge_order : sequence of Edge, optional
            Forwarded unchanged to the contractor (presumably the order
            of the dangling edges of the result).
        ignore_edge_order : bool, default: False
            Forwarded unchanged to the contractor.

        Returns
        -------
        The value returned by :py:func:`ordered_nodes_contractor`; callers
        in this class read its ``.tensor`` attribute.

        """
        return ordered_nodes_contractor(nodes,
                                        self.tn_path_optimizer,
                                        output_edge_order,
                                        ignore_edge_order)

    def _reinitialise_modules(self) -> None:
        """Reinitialise all modules in the model.

        Calls ``reset_parameters()`` on every module returned by
        ``self.modules()`` (which includes the model itself). Modules
        without such a method are skipped.

        """
        for module in self.modules():
            try:
                module.reset_parameters()  # type: ignore[operator]
            # NOTE(review): catching TypeError also hides TypeErrors raised
            # *inside* a reset_parameters implementation; kept as-is to
            # preserve the existing best-effort behaviour.
            except (AttributeError, TypeError):
                pass

    def initialise_weights(self) -> None:
        """Initialise the weights of the model.

        Resets every submodule that exposes ``reset_parameters`` and then
        creates one weight tensor per symbol. Each tensor is drawn
        uniformly from [-1, 1) and divided by a scale that depends on the
        symbol's ``directed_cod``.

        Raises
        ------
        ValueError
            If `model.symbols` are not initialised.

        """
        # Validate before touching any module so that a failed call leaves
        # the model unchanged.
        if not self.symbols:
            raise ValueError('Symbols not initialised. Instantiate through '
                             '`PytorchModel.from_diagrams()`.')

        self._reinitialise_modules()

        def mean(size: int) -> float:
            # Correction term tabulated for sizes 1-5 (size 0 maps to NaN);
            # larger sizes use a closed-form expression.
            # NOTE(review): the derivation of these constants is not
            # documented here.
            if size < 6:
                correction_factor = [float('nan'), 3, 2.6, 2, 1.6, 1.3][size]
            else:
                correction_factor = 1 / (0.16 * size - 0.04)
            return sqrt(size / 3 - 1 / (15 - correction_factor))

        # torch.rand samples from [0, 1), so 2 * rand - 1 lies in [-1, 1).
        self.weights = torch.nn.ParameterList([
            (2 * torch.rand(w.size) - 1) / mean(w.directed_cod)
            for w in self.symbols
        ])

    def _load_checkpoint(self, checkpoint: Checkpoint) -> None:
        """Load the model weights and symbols from a lambeq
        :py:class:`.Checkpoint`.

        Parameters
        ----------
        checkpoint : :py:class:`.Checkpoint`
            Checkpoint containing the model weights,
            symbols and additional information.

        """
        self.symbols = checkpoint['model_symbols']
        self.weights = checkpoint['model_weights']
        self.load_state_dict(checkpoint['model_state_dict'])
        self.tn_path_optimizer.restore_from_checkpoint(checkpoint)

    def _make_checkpoint(self) -> Checkpoint:
        """Create checkpoint that contains the model weights and symbols.

        Returns
        -------
        :py:class:`.Checkpoint`
            Checkpoint containing the model weights, symbols and
            additional information.

        """
        checkpoint = Checkpoint()
        checkpoint.add_many({'model_symbols': self.symbols,
                             'model_weights': self.weights,
                             'model_state_dict': self.state_dict()})
        checkpoint = self.tn_path_optimizer.store_to_checkpoint(checkpoint)
        return checkpoint

    def get_diagram_output(self, diagrams: list[Diagram]) -> torch.Tensor:
        """Contract diagrams using tensornetwork.

        Parameters
        ----------
        diagrams : list of :py:class:`~lambeq.backend.tensor.Diagram`
            The :py:class:`Diagrams <lambeq.backend.tensor.Diagram>` to be
            evaluated.

        Raises
        ------
        ValueError
            If `model.symbols` and `model.weights` have different lengths,
            i.e. the model has not been fully initialised.
        KeyError
            If a diagram contains a symbol that is not in `model.symbols`.

        Returns
        -------
        torch.Tensor
            Resulting tensor.

        """
        import tensornetwork as tn

        if len(self.symbols) != len(self.weights):
            # `zip` would silently drop the unmatched entries, and the
            # problem would surface later as a misleading "Unknown symbol"
            # error, so report the real cause up front.
            raise ValueError(
                'Model symbols and weights are not consistently '
                f'initialised: got {len(self.symbols)} symbols and '
                f'{len(self.weights)} weights. Instantiate through '
                '`PytorchModel.from_diagrams()` or call '
                '`initialise_weights()`.'
            )

        parameters = dict(zip(self.symbols, self.weights))
        # Substitute weights into copies so the caller's diagrams keep
        # their symbolic box data.
        diagrams = fast_deepcopy(diagrams)
        for diagram in diagrams:
            for box in diagram.boxes:
                if isinstance(box.data, Symbol):
                    try:
                        box.data = parameters[box.data]  # type: ignore[attr-defined]  # noqa: E501
                    except KeyError as e:
                        raise KeyError(
                            f'Unknown symbol: {repr(box.data)}'
                        ) from e

        with backend('pytorch'), tn.DefaultBackend('pytorch'):
            return torch.stack([self._tn_contract(*d.to_tn()).tensor
                                for d in diagrams])

    def forward(self, x: list[Diagram]) -> torch.Tensor:
        """Perform default forward pass by contracting tensors.

        In case of a different datapoint (e.g. list of tuple) or
        additional computational steps, please override this method.

        Parameters
        ----------
        x : list of :py:class:`~lambeq.backend.tensor.Diagram`
            The :py:class:`Diagrams <lambeq.backend.tensor.Diagram>` to be
            evaluated.

        Returns
        -------
        torch.Tensor
            Tensor containing model's prediction.

        """
        return self.get_diagram_output(x)