Source code for lambeq.training.pytorch_quantum_model
# Copyright 2021-2024 Cambridge Quantum Computing Ltd.## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License."""PytorchQuantumModel===================Module implementing a basic lambeq model based on a Pytorch backendfor training quantum circuits with Pytorch automatic gradients."""from__future__importannotationsimporttorchfromlambeq.backend.numerical_backendimportbackendfromlambeq.backend.quantumimportDiagramasCircuitfromlambeq.backend.tensorimportDiagramfromlambeq.training.pytorch_modelimportPytorchModelfromlambeq.training.quantum_modelimportQuantumModelfromlambeq.training.tn_path_optimizerimportTnPathOptimizer
class PytorchQuantumModel(PytorchModel, QuantumModel):
    """Quantum-pipeline lambeq model evaluated with PyTorch.

    Circuits are contracted as tensor networks on the PyTorch backend,
    with the model weights held in a :py:class:`torch.nn.Parameter`.

    """

    # Assigned by `initialise_weights`: one entry per element of
    # `self.symbols`.
    weights: torch.nn.Parameter  # type: ignore[assignment]

    def __init__(self,
                 tn_path_optimizer: TnPathOptimizer | None = None) -> None:
        """Initialise a PytorchQuantumModel.

        Parameters
        ----------
        tn_path_optimizer : TnPathOptimizer, optional
            Passed through unchanged to ``PytorchModel.__init__``.

        """
        # Each base is initialised explicitly, in this order, instead
        # of going through `super()`.
        PytorchModel.__init__(self, tn_path_optimizer)
        QuantumModel.__init__(self)

    def initialise_weights(self) -> None:
        """Create a new random weight vector, one value per symbol.

        ``self._reinitialise_modules()`` is called first, before the
        symbols are checked.

        Raises
        ------
        ValueError
            If `model.symbols` are not initialised.

        """
        self._reinitialise_modules()

        if self.symbols:
            # `torch.rand` samples uniformly from [0, 1).
            initial_values = torch.rand(len(self.symbols))
            self.weights = torch.nn.Parameter(initial_values)
            return

        raise ValueError('Symbols not initialised. Instantiate through '
                         '`PytorchQuantumModel.from_diagrams()`.')

    def get_diagram_output(self, diagrams: list[Diagram]) -> torch.Tensor:
        """Contract diagrams using tensornetwork.

        The current weights are substituted into the diagrams via
        ``self._fast_subs``; every resulting circuit is converted to a
        tensor network, contracted with ``self._tn_contract`` and
        passed through ``self._normalise_vector``.

        Parameters
        ----------
        diagrams : list of :py:class:`~lambeq.backend.tensor.Diagram`
            The :py:class:`Diagrams <lambeq.backend.tensor.Diagram>` to
            be evaluated.

        Raises
        ------
        AssertionError
            If a diagram is not a quantum ``Circuit`` after weight
            substitution (only while assertions are enabled).

        Returns
        -------
        torch.Tensor
            The per-diagram results stacked along a new first
            dimension.

        Notes
        -----
        This method does not itself verify that `model.weights` or
        `model.symbols` are initialised; any error for that case would
        have to come from the helpers it calls.

        """
        import tensornetwork as tn

        circuits = self._fast_subs(diagrams, self.weights)
        outputs = []

        with backend('pytorch'), tn.DefaultBackend('pytorch'):
            for circuit in circuits:
                assert isinstance(circuit, Circuit)
                nodes, edges = circuit.to_tn()

                # Ensure uniform tensor dtypes for contraction: find
                # the dtype every node can be promoted to, starting
                # from the narrowest one (bool) ...
                common_dtype = torch.bool
                for node in nodes:
                    common_dtype = torch.promote_types(common_dtype,
                                                       node.tensor.dtype)
                # ... then cast only the nodes that differ from it.
                for node in nodes:
                    if node.tensor.dtype == common_dtype:
                        continue
                    node.tensor = node.tensor.to(common_dtype)

                contracted = self._tn_contract(nodes, edges).tensor
                if circuit.is_mixed:
                    outcome = contracted
                else:
                    # Non-mixed circuits use the squared magnitude of
                    # the contraction (presumably amplitudes being
                    # turned into probabilities -- TODO confirm).
                    outcome = torch.square(torch.abs(contracted))
                outputs.append(self._normalise_vector(outcome))

        return torch.stack(outputs)

    def forward(self, x: list[Diagram]) -> torch.Tensor:
        """Perform default forward pass by contracting tensors.

        In case of a different datapoint (e.g. list of tuple) or
        additional computational steps, please override this method.

        Parameters
        ----------
        x : list of :py:class:`~lambeq.backend.tensor.Diagram`
            The :py:class:`Diagrams <lambeq.backend.tensor.Diagram>` to
            be evaluated.

        Returns
        -------
        torch.Tensor
            Tensor containing model's prediction.

        """
        return self.get_diagram_output(x)