lambeq.training

class lambeq.training.Checkpoint

Bases: collections.abc.Mapping

Checkpoint class.

Attributes
entries : dict

All data, stored as part of the checkpoint.

__init__() None

Initialise a Checkpoint.

add_many(values: Mapping[str, Any]) None

Adds several values into the checkpoint.

Parameters
values : Mapping from str to any

The values to be added into the checkpoint.

classmethod from_file(path: Union[str, os.PathLike[str]]) lambeq.training.checkpoint.Checkpoint

Load the checkpoint contents from the file.

Parameters
path : str or PathLike

Path to the checkpoint file.

Raises
FileNotFoundError

If no file is found at the given path.

get(k[, d])

Return D[k] if k in D, else d. d defaults to None.

items()

Return a set-like object providing a view on D's items.

keys()

Return a set-like object providing a view on D's keys.

to_file(path: Union[str, os.PathLike[str]]) None

Save entries to a file and delete the in-memory copy.

Parameters
path : str or PathLike

Path to the checkpoint file.

values()

Return an object providing a view on D's values.

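Since Checkpoint derives from collections.abc.Mapping, the inherited get(), items(), keys() and values() only need __getitem__, __iter__ and __len__ to work. The standard-library sketch below illustrates that pattern together with the documented add_many/to_file/from_file behaviour. The class name SimpleCheckpoint and the use of pickle as the file format are assumptions for illustration, not lambeq's implementation.

```python
import os
import pickle
import tempfile
from collections.abc import Mapping
from typing import Any, Iterator, Union


class SimpleCheckpoint(Mapping):
    """Hypothetical stand-in mirroring the documented Checkpoint interface."""

    def __init__(self) -> None:
        self.entries: dict[str, Any] = {}

    # The three abstract methods required by collections.abc.Mapping;
    # get(), items(), keys() and values() are inherited from it.
    def __getitem__(self, key: str) -> Any:
        return self.entries[key]

    def __iter__(self) -> Iterator[str]:
        return iter(self.entries)

    def __len__(self) -> int:
        return len(self.entries)

    def add_many(self, values: Mapping[str, Any]) -> None:
        # Add several values at once.
        self.entries.update(values)

    def to_file(self, path: Union[str, os.PathLike]) -> None:
        # Save the entries (pickle is an assumption), then drop the in-memory copy.
        with open(path, "wb") as f:
            pickle.dump(self.entries, f)
        self.entries = {}

    @classmethod
    def from_file(cls, path: Union[str, os.PathLike]) -> "SimpleCheckpoint":
        # open() raises FileNotFoundError if nothing exists at `path`.
        checkpoint = cls()
        with open(path, "rb") as f:
            checkpoint.add_many(pickle.load(f))
        return checkpoint


ckpt = SimpleCheckpoint()
ckpt.add_many({"epoch": 3, "weights": [0.1, 0.2]})
print(ckpt.get("epoch"), ckpt.get("missing"))  # 3 None
path = os.path.join(tempfile.mkdtemp(), "demo.lt")
ckpt.to_file(path)
print(len(ckpt))  # 0
restored = SimpleCheckpoint.from_file(path)
print(sorted(restored.keys()))  # ['epoch', 'weights']
```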
class lambeq.training.Dataset(data: list[Any], targets: list[Any], batch_size: int = 0, shuffle: bool = True)

Bases: object

Dataset class for the training of a lambeq model.

Data is returned in the format of discopy.tensor.Tensor’s backend, which by default is set to NumPy. For example, to access the dataset as PyTorch tensors:

>>> from discopy.tensor import Tensor
>>> from lambeq import Dataset
>>> dataset = Dataset(['data1'], [[0, 1, 2, 3]])
>>> with Tensor.backend('pytorch'):
...     print(dataset[0])  # becomes a PyTorch tensor
('data1', tensor([0, 1, 2, 3]))
>>> print(dataset[0])  # a NumPy array again
('data1', array([0, 1, 2, 3]))

__init__(data: list[Any], targets: list[Any], batch_size: int = 0, shuffle: bool = True) None

Initialise a Dataset for lambeq training.

Parameters
data : list

Data used for training.

targets : list

List of labels.

batch_size : int, default: 0

Batch size for batch generation. The default of 0 uses the full dataset as a single batch.

shuffle : bool, default: True

Enable data shuffling during training.

Raises
ValueError

When ‘data’ and ‘targets’ do not match in size.
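The batch_size semantics described above (0 meaning a single batch containing the whole dataset) and the size check can be illustrated with a small standard-library generator. iterate_batches is a hypothetical helper, not part of lambeq, and it deliberately leaves out shuffling and the conversion to the discopy tensor backend.

```python
from typing import Any, Iterator


def iterate_batches(data: list[Any], targets: list[Any],
                    batch_size: int = 0) -> Iterator[tuple[list[Any], list[Any]]]:
    """Hypothetical helper mimicking the documented batching of Dataset."""
    if len(data) != len(targets):
        raise ValueError("`data` and `targets` must have the same length.")
    # batch_size=0 means "use the full dataset as a single batch".
    size = batch_size if batch_size > 0 else max(len(data), 1)
    for start in range(0, len(data), size):
        yield data[start:start + size], targets[start:start + size]


data = ["s1", "s2", "s3", "s4", "s5"]
targets = [0, 1, 0, 1, 1]
print([len(batch) for batch, _ in iterate_batches(data, targets, batch_size=2)])  # [2, 2, 1]
print(len(list(iterate_batches(data, targets))))  # 1
```

The last batch is simply shorter when the dataset size is not a multiple of batch_size.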

static shuffle_data(data: list[Any], targets: list[Any]) tuple[list[Any], list[Any]]

Shuffle a given dataset.

Parameters
data : list

List of data points.

targets : list

List of labels.

Returns
Tuple of list and list

The shuffled dataset.
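Shuffling data and targets independently would break the pairing between each data point and its label, so both lists must be reordered with the same permutation. The sketch below shows one way to do that with the standard library; shuffle_pairs and its seed argument are hypothetical and are not lambeq's shuffle_data implementation.

```python
import random
from typing import Any, Optional


def shuffle_pairs(data: list[Any], targets: list[Any],
                  seed: Optional[int] = None) -> tuple[list[Any], list[Any]]:
    """Hypothetical equivalent of shuffle_data: one permutation for both lists."""
    if len(data) != len(targets):
        raise ValueError("`data` and `targets` must have the same length.")
    pairs = list(zip(data, targets))
    random.Random(seed).shuffle(pairs)  # shuffle (data point, label) pairs together
    return [d for d, _ in pairs], [t for _, t in pairs]


data = ["s1", "s2", "s3", "s4"]
targets = [0, 1, 0, 1]
label_of = dict(zip(data, targets))

shuffled_data, shuffled_targets = shuffle_pairs(data, targets, seed=1)
# Every data point still carries its original label.
print(all(label_of[d] == t for d, t in zip(shuffled_data, shuffled_targets)))  # True
```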

class lambeq.training.Model

Bases: abc.ABC

Model base class.

Attributes
symbols : list of symbols

A sorted list of all Symbols occurring in the data.

weights : Collection

A data structure containing the numeric values of the model’s parameters.

__call__(*args: Any, **kwds: Any) Any

Call self as a function.

__init__() None

Initialise an instance of Model base class.

abstract forward(x: list[Any]) Any

The forward pass of the model.

classmethod from_checkpoint(checkpoint_path: Union[str, os.PathLike[str]], **kwargs: Any) lambeq.training.model.Model

Load the weights and symbols from a training checkpoint.

Parameters
checkpoint_path : str or PathLike

Path that points to the checkpoint file.

Other Parameters
backend_config : dict

Dictionary containing the backend configuration for the TketModel. Must include the fields ‘backend’, ‘compilation’ and ‘shots’.

classmethod from_diagrams(diagrams: list[Diagram], **kwargs: Any) Model

Build model from a list of Diagrams.

Parameters
diagrams : list of Diagram

The tensor or circuit diagrams to be evaluated.

Other Parameters
backend_config : dict

Dictionary containing the backend configuration for the TketModel. Must include the fields ‘backend’, ‘compilation’ and ‘shots’.

use_jit : bool, default: False

Whether to use JAX’s Just-In-Time compilation in NumpyModel.

abstract get_diagram_output(diagrams: list[Diagram]) Any

Return the diagram prediction.

Parameters
diagrams : list of Diagram

The tensor or circuit diagrams to be evaluated.

abstract initialise_weights() None

Initialise the weights of the model.

load(checkpoint_path: Union[str, os.PathLike[str]]) None

Load model data from a path pointing to a lambeq checkpoint.

Checkpoints that are created by a lambeq Trainer usually have the extension .lt.

Parameters
checkpoint_path : str or PathLike

Path that points to the checkpoint file.

save(checkpoint_path: Union[str, os.PathLike[str]]) None

Create a lambeq Checkpoint and save to a path.

Example:

>>> from lambeq import PytorchModel
>>> model = PytorchModel()
>>> model.save('my_checkpoint.lt')

Parameters
checkpoint_path : str or PathLike

Path that points to the checkpoint file.
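Model only fixes an interface: a concrete subclass has to implement forward, get_diagram_output and initialise_weights, and calling the model runs its forward pass. The standard-library sketch below mimics that contract. ToyModelBase and SumModel are hypothetical, and real lambeq diagrams are replaced by plain lists of symbol names so that the example runs without lambeq.

```python
from abc import ABC, abstractmethod
from typing import Any


class ToyModelBase(ABC):
    """Hypothetical stand-in for the documented Model contract."""

    def __init__(self) -> None:
        self.symbols: list[str] = []
        self.weights: list[float] = []

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # Calling the model runs its forward pass.
        return self.forward(*args, **kwargs)

    @abstractmethod
    def forward(self, x: list[Any]) -> Any:
        """The forward pass of the model."""

    @abstractmethod
    def get_diagram_output(self, diagrams: list[Any]) -> Any:
        """Return the prediction for each diagram."""

    @abstractmethod
    def initialise_weights(self) -> None:
        """Initialise the weights of the model."""


class SumModel(ToyModelBase):
    """Toy model: each 'diagram' is a list of symbol names, and its output
    is the sum of the weights attached to those symbols."""

    def initialise_weights(self) -> None:
        if not self.symbols:
            raise ValueError("Symbols not initialised.")
        self.weights = [1.0] * len(self.symbols)

    def get_diagram_output(self, diagrams: list[list[str]]) -> list[float]:
        index = {symbol: i for i, symbol in enumerate(self.symbols)}
        return [sum(self.weights[index[s]] for s in d) for d in diagrams]

    def forward(self, x: list[list[str]]) -> list[float]:
        return self.get_diagram_output(x)


model = SumModel()
model.symbols = sorted({"b", "a"})  # symbols are kept as a sorted list
model.initialise_weights()
print(model([["a"], ["a", "b"]]))  # [1.0, 2.0]
```

Instantiating ToyModelBase directly raises a TypeError, because its abstract methods are not implemented.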

class lambeq.training.NumpyModel(use_jit: bool = False)

Bases: lambeq.training.quantum_model.QuantumModel

A lambeq model for an exact classical simulation of a quantum pipeline.

__call__(*args: Any, **kwargs: Any) Any

Call self as a function.

__init__(use_jit: bool = False) None

Initialise a NumpyModel.

Parameters
use_jit : bool, default: False

Whether to use JAX’s Just-In-Time compilation.

forward(x: list[Diagram]) Any

Perform default forward pass of a lambeq model.

In case of a different datapoint (e.g. a list of tuples) or additional computational steps, please override this method.

Parameters
x : list of Diagram

The Circuits to be evaluated.

Returns
numpy.ndarray

Array containing model’s prediction.

classmethod from_checkpoint(checkpoint_path: Union[str, os.PathLike[str]], **kwargs: Any) lambeq.training.model.Model

Load the weights and symbols from a training checkpoint.

Parameters
checkpoint_path : str or PathLike

Path that points to the checkpoint file.

Other Parameters
backend_config : dict

Dictionary containing the backend configuration for the TketModel. Must include the fields ‘backend’, ‘compilation’ and ‘shots’.

classmethod from_diagrams(diagrams: list[Diagram], **kwargs: Any) Model

Build model from a list of Diagrams.

Parameters
diagrams : list of Diagram

The tensor or circuit diagrams to be evaluated.

Other Parameters
backend_config : dict

Dictionary containing the backend configuration for the TketModel. Must include the fields ‘backend’, ‘compilation’ and ‘shots’.

use_jit : bool, default: False

Whether to use JAX’s Just-In-Time compilation in NumpyModel.

get_diagram_output(diagrams: list[Diagram]) Union[jnp.ndarray, numpy.ndarray]

Return the exact prediction for each diagram.

Parameters
diagrams : list of Diagram

The Circuits to be evaluated.

Returns
np.ndarray

Resulting array.

Raises
ValueError

If model.weights or model.symbols are not initialised.

initialise_weights() None

Initialise the weights of the model.

Raises
ValueError

If model.symbols are not initialised.

load(checkpoint_path: Union[str, os.PathLike[str]]) None

Load model data from a path pointing to a lambeq checkpoint.

Checkpoints that are created by a lambeq Trainer usually have the extension .lt.

Parameters
checkpoint_path : str or PathLike

Path that points to the checkpoint file.

save(checkpoint_path: Union[str, os.PathLike[str]]) None

Create a lambeq Checkpoint and save to a path.

Example:

>>> from lambeq import PytorchModel
>>> model = PytorchModel()
>>> model.save('my_checkpoint.lt')

Parameters
checkpoint_path : str or PathLike

Path that points to the checkpoint file.

weights: np.ndarray
class lambeq.training.Optimizer(model: Model, hyperparams: dict[Any, Any], loss_fn: Callable[[Any, Any], float], bounds: Optional[ArrayLike] = None)

Bases: abc.ABC

Optimizer base class.

__init__(model: Model, hyperparams: dict[Any, Any], loss_fn: Callable[[Any, Any], float], bounds: Optional[ArrayLike] = None) None

Initialise the optimizer base class.

Parameters
model : QuantumModel

A lambeq model.

hyperparams : dict of str to float

A dictionary containing the model’s hyperparameters.

loss_fn : Callable

A loss function of the form loss(prediction, labels).

bounds : ArrayLike, optional

The range of each of the model’s parameters.

abstract backward(batch: tuple[Iterable[Any], np.ndarray]) float

Calculate the gradients of the loss function.

The gradient is calculated with respect to the model parameters.

Parameters
batch : tuple of list and numpy.ndarray

Current batch.

Returns
float

The calculated loss.

abstract load_state_dict(state: Mapping[str, Any]) None

Load state of the optimizer from the state dictionary.

abstract state_dict() dict[str, Any]

Return optimizer states as a dictionary.

abstract step() None

Perform an optimisation step.

zero_grad() None

Reset the gradients to zero.
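To make the Optimizer contract concrete, the sketch below implements backward, step, zero_grad, state_dict and load_state_dict for a toy one-parameter model using only the standard library. OptimizerBase, FiniteDifferenceGD, LinearToy and the hyperparameter keys "lr" and "eps" are hypothetical. Central finite differences are swapped in for brevity and are not how lambeq's own optimizers estimate gradients; calling zero_grad at the end of step is also a choice of this sketch.

```python
from abc import ABC, abstractmethod
from typing import Any, Callable, Iterable, Mapping, Optional


class OptimizerBase(ABC):
    """Hypothetical stand-in for the documented Optimizer interface."""

    def __init__(self, model: Any, hyperparams: dict[str, float],
                 loss_fn: Callable[[Any, Any], float],
                 bounds: Optional[list[tuple[float, float]]] = None) -> None:
        self.model = model
        self.hyperparams = hyperparams
        self.loss_fn = loss_fn
        self.bounds = bounds
        self.gradient = [0.0] * len(model.weights)

    def zero_grad(self) -> None:
        # Reset the accumulated gradients to zero.
        self.gradient = [0.0] * len(self.model.weights)

    @abstractmethod
    def backward(self, batch: tuple[Iterable[Any], Any]) -> float: ...

    @abstractmethod
    def step(self) -> None: ...

    @abstractmethod
    def state_dict(self) -> dict[str, Any]: ...

    @abstractmethod
    def load_state_dict(self, state: Mapping[str, Any]) -> None: ...


class FiniteDifferenceGD(OptimizerBase):
    """Plain gradient descent with central finite differences (illustrative)."""

    def backward(self, batch: tuple[Iterable[Any], Any]) -> float:
        x, y = batch
        eps = self.hyperparams.get("eps", 1e-6)
        w = self.model.weights
        for i in range(len(w)):
            original = w[i]
            w[i] = original + eps
            loss_up = self.loss_fn(self.model(x), y)
            w[i] = original - eps
            loss_down = self.loss_fn(self.model(x), y)
            w[i] = original  # restore the parameter
            self.gradient[i] += (loss_up - loss_down) / (2 * eps)
        return self.loss_fn(self.model(x), y)

    def step(self) -> None:
        lr = self.hyperparams["lr"]
        for i, g in enumerate(self.gradient):
            new_value = self.model.weights[i] - lr * g
            if self.bounds is not None:  # clip each parameter to its range
                low, high = self.bounds[i]
                new_value = min(max(new_value, low), high)
            self.model.weights[i] = new_value
        self.zero_grad()

    def state_dict(self) -> dict[str, Any]:
        return {"gradient": list(self.gradient)}

    def load_state_dict(self, state: Mapping[str, Any]) -> None:
        self.gradient = list(state["gradient"])


class LinearToy:
    """Hypothetical one-parameter model: prediction = w * x."""

    def __init__(self) -> None:
        self.weights = [0.0]

    def __call__(self, xs: list[float]) -> list[float]:
        return [self.weights[0] * x for x in xs]


def mse(predictions: list[float], labels: list[float]) -> float:
    return sum((p - t) ** 2 for p, t in zip(predictions, labels)) / len(labels)


model = LinearToy()
optimizer = FiniteDifferenceGD(model, {"lr": 0.1}, mse, bounds=[(-5.0, 5.0)])
batch = ([1.0, 2.0], [2.0, 4.0])  # labels follow y = 2x
for _ in range(50):
    loss = optimizer.backward(batch)
    optimizer.step()
print(round(model.weights[0], 3))  # 2.0
```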

class lambeq.training.PytorchModel

Bases: lambeq.training.model.Model, torch.nn.modules.module.Module

A lambeq model for the classical pipeline using PyTorch.

T_destination

alias of TypeVar(‘T_destination’, bound=Dict[str, Any])

__call__(*args: Any, **kwds: Any) Any

Call self as a function.

__init__() None

Initialise a PytorchModel.

add_module(name: str, module: Optional[torch.nn.modules.module.Module]) None

Adds a child module to the current module.

The module can be accessed as an attribute using the given name.

Args:

name (str): name of the child module. The child module can be accessed from this module using the given name.

module (Module): child module to be added to the module.

apply(fn: Callable[[torch.nn.modules.module.Module], None]) torch.nn.modules.module.T

Applies fn recursively to every submodule (as returned by .children()) as well as self. Typical use includes initializing the parameters of a model (see also torch.nn.init).

Args:

fn (Module -> None): function to be applied to each submodule

Returns:

Module: self

Example:

>>> import torch
>>> from torch import nn
>>> @torch.no_grad()
... def init_weights(m):
...     print(m)
...     if type(m) == nn.Linear:
...         m.weight.fill_(1.0)
...         print(m.weight)
>>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
>>> net.apply(init_weights)
Linear(in_features=2, out_features=2, bias=True)
Parameter containing:
tensor([[1., 1.],
        [1., 1.]], requires_grad=True)
Linear(in_features=2, out_features=2, bias=True)
Parameter containing:
tensor([[1., 1.],
        [1., 1.]], requires_grad=True)
Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=2, out_features=2, bias=True)
)

bfloat16() torch.nn.modules.module.T

Casts all floating point parameters and buffers to bfloat16 datatype.

Note

This method modifies the module in-place.

Returns:

Module: self

buffers(recurse: bool = True) Iterator[torch.Tensor]

Returns an iterator over module buffers.

Args:

recurse (bool): if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module.

Yields:

torch.Tensor: module buffer

Example:

>>> # xdoctest: +SKIP("undefined vars")
>>> for buf in model.buffers():
...     print(type(buf), buf.size())
<class 'torch.Tensor'> torch.Size([20])
<class 'torch.Tensor'> torch.Size([20, 1, 5, 5])

children() Iterator[torch.nn.modules.module.Module]

Returns an iterator over immediate children modules.

Yields:

Module: a child module

cpu() torch.nn.modules.module.T

Moves all model parameters and buffers to the CPU.

Note

This method modifies the module in-place.

Returns:

Module: self

cuda(device: Optional[Union[int, torch.device]] = None) torch.nn.modules.module.T

Moves all model parameters and buffers to the GPU.

This also makes associated parameters and buffers different objects, so it should be called before constructing the optimizer if the module will live on GPU while being optimized.

Note

This method modifies the module in-place.

Args:

device (int, optional): if specified, all parameters will be copied to that device

Returns:

Module: self

double() torch.nn.modules.module.T

Casts all floating point parameters and buffers to double datatype.

Note

This method modifies the module in-place.

Returns:

Module: self

dump_patches: bool = False
eval() torch.nn.modules.module.T

Sets the module in evaluation mode.

This has an effect only on certain modules. See the documentation of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. Dropout, BatchNorm, etc.

This is equivalent to self.train(False).

See the PyTorch notes on locally disabling gradient computation for a comparison between .eval() and several similar mechanisms that may be confused with it.

Returns:

Module: self

extra_repr() str

Set the extra representation of the module.

To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.

float() torch.nn.modules.module.T

Casts all floating point parameters and buffers to float datatype.

Note

This method modifies the module in-place.

Returns:

Module: self

forward(x: list[Diagram]) torch.Tensor

Perform default forward pass by contracting tensors.

In case of a different datapoint (e.g. a list of tuples) or additional computational steps, please override this method.

Parameters
x : list of Diagram

The Diagrams to be evaluated.

Returns
torch.Tensor

Tensor containing model’s prediction.

classmethod from_checkpoint(checkpoint_path: Union[str, os.PathLike[str]], **kwargs: Any) lambeq.training.model.Model

Load the weights and symbols from a training checkpoint.

Parameters
checkpoint_path : str or PathLike

Path that points to the checkpoint file.

Other Parameters
backend_config : dict

Dictionary containing the backend configuration for the TketModel. Must include the fields ‘backend’, ‘compilation’ and ‘shots’.

classmethod from_diagrams(diagrams: list[Diagram], **kwargs: Any) Model

Build model from a list of Diagrams.

Parameters
diagrams : list of Diagram

The tensor or circuit diagrams to be evaluated.

Other Parameters
backend_config : dict

Dictionary containing the backend configuration for the TketModel. Must include the fields ‘backend’, ‘compilation’ and ‘shots’.

use_jit : bool, default: False

Whether to use JAX’s Just-In-Time compilation in NumpyModel.

get_buffer(target: str) torch.Tensor

Returns the buffer given by target if it exists, otherwise throws an error.

See the docstring for get_submodule for a more detailed explanation of this method’s functionality as well as how to correctly specify target.

Args:

target: The fully-qualified string name of the buffer to look for. (See get_submodule for how to specify a fully-qualified string.)

Returns:

torch.Tensor: The buffer referenced by target

Raises:

AttributeError: If the target string references an invalid path or resolves to something that is not a buffer

get_diagram_output(diagrams: list[Diagram]) torch.Tensor

Contract diagrams using tensornetwork.

Parameters
diagrams : list of Diagram

The Diagrams to be evaluated.

Returns
torch.Tensor

Resulting tensor.

Raises
ValueError

If model.weights or model.symbols are not initialised.

get_extra_state() Any

Returns any extra state to include in the module’s state_dict. Implement this and a corresponding set_extra_state() for your module if you need to store extra state. This function is called when building the module’s state_dict().

Note that extra state should be pickleable to ensure working serialization of the state_dict. We only provide backwards compatibility guarantees for serializing Tensors; other objects may break backwards compatibility if their serialized pickled form changes.

Returns:

object: Any extra state to store in the module’s state_dict

get_parameter(target: str) torch.nn.parameter.Parameter

Returns the parameter given by target if it exists, otherwise throws an error.

See the docstring for get_submodule for a more detailed explanation of this method’s functionality as well as how to correctly specify target.

Args:

target: The fully-qualified string name of the Parameter to look for. (See get_submodule for how to specify a fully-qualified string.)

Returns:

torch.nn.Parameter: The Parameter referenced by target

Raises:

AttributeError: If the target string references an invalid path or resolves to something that is not an nn.Parameter

get_submodule(target: str) torch.nn.modules.module.Module

Returns the submodule given by target if it exists, otherwise throws an error.

For example, let’s say you have an nn.Module A that looks like this:

A(
    (net_b): Module(
        (net_c): Module(
            (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2))
        )
        (linear): Linear(in_features=100, out_features=200, bias=True)
    )
)

(The diagram shows an nn.Module A. A has a nested submodule net_b, which itself has two submodules net_c and linear. net_c then has a submodule conv.)

To check whether or not we have the linear submodule, we would call get_submodule("net_b.linear"). To check whether we have the conv submodule, we would call get_submodule("net_b.net_c.conv").

The runtime of get_submodule is bounded by the degree of module nesting in target. A query against named_modules achieves the same result, but it is O(N) in the number of transitive modules. So, for a simple check to see if some submodule exists, get_submodule should always be used.

Args:

target: The fully-qualified string name of the submodule to look for. (See above example for how to specify a fully-qualified string.)

Returns:

torch.nn.Module: The submodule referenced by target

Raises:

AttributeError: If the target string references an invalid path or resolves to something that is not an nn.Module
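The nested module A described above can be rebuilt and queried directly. The classes NetC, NetB and A below are illustrative names chosen to match the diagram; get_submodule is the torch.nn.Module method documented here.

```python
from torch import nn


class NetC(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(16, 33, kernel_size=3, stride=2)


class NetB(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.net_c = NetC()
        self.linear = nn.Linear(100, 200)


class A(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.net_b = NetB()


a = A()
print(a.get_submodule("net_b.linear"))  # Linear(in_features=100, out_features=200, bias=True)
conv = a.get_submodule("net_b.net_c.conv")
try:
    a.get_submodule("net_b.missing")  # invalid path
except AttributeError as err:
    print("AttributeError:", err)
```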

half() torch.nn.modules.module.T

Casts all floating point parameters and buffers to half datatype.

Note

This method modifies the module in-place.

Returns:

Module: self

initialise_weights() None

Initialise the weights of the model.

Raises
ValueError

If model.symbols are not initialised.

ipu(device: Optional[Union[int, torch.device]] = None) torch.nn.modules.module.T

Moves all model parameters and buffers to the IPU.

This also makes associated parameters and buffers different objects, so it should be called before constructing the optimizer if the module will live on IPU while being optimized.

Note

This method modifies the module in-place.

Arguments:

device (int, optional): if specified, all parameters will be copied to that device

Returns:

Module: self

load(checkpoint_path: Union[str, os.PathLike[str]]) None

Load model data from a path pointing to a lambeq checkpoint.

Checkpoints that are created by a lambeq Trainer usually have the extension .lt.

Parameters
checkpoint_path : str or PathLike

Path that points to the checkpoint file.

load_state_dict(state_dict: Mapping[str, Any], strict: bool = True)

Copies parameters and buffers from state_dict into this module and its descendants. If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.

Args:

state_dict (dict): a dict containing parameters and persistent buffers.

strict (bool, optional): whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True

Returns:

NamedTuple with missing_keys and unexpected_keys fields:
  • missing_keys is a list of str containing the missing keys
  • unexpected_keys is a list of str containing the unexpected keys

Note:

If a parameter or buffer is registered as None and its corresponding key exists in state_dict, load_state_dict() will raise a RuntimeError.
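With strict=False, mismatched keys are reported instead of raising an error. In the sketch below, a plain nn.Linear state dict is loaded into an nn.Sequential wrapper whose keys carry a "0." prefix, so no key matches and both lists in the returned NamedTuple are populated. The modules are arbitrary choices for illustration.

```python
from torch import nn

source = nn.Linear(2, 2)                 # state_dict keys: 'weight', 'bias'
target = nn.Sequential(nn.Linear(2, 2))  # state_dict keys: '0.weight', '0.bias'

result = target.load_state_dict(source.state_dict(), strict=False)
print(sorted(result.missing_keys))     # ['0.bias', '0.weight']
print(sorted(result.unexpected_keys))  # ['bias', 'weight']
```

Because no key matches, nothing is actually copied here; with strict=True the same call raises a RuntimeError.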

modules() Iterator[torch.nn.modules.module.Module]

Returns an iterator over all modules in the network.

Yields:

Module: a module in the network

Note:

Duplicate modules are returned only once. In the following example, l will be returned only once.

Example:

>>> from torch import nn
>>> l = nn.Linear(2, 2)
>>> net = nn.Sequential(l, l)
>>> for idx, m in enumerate(net.modules()):
...     print(idx, '->', m)
0 -> Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=2, out_features=2, bias=True)
)
1 -> Linear(in_features=2, out_features=2, bias=True)

named_buffers(prefix: str = '', recurse: bool = True) Iterator[Tuple[str, torch.Tensor]]

Returns an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

Args:

prefix (str): prefix to prepend to all buffer names.

recurse (bool): if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module.

Yields:

(str, torch.Tensor): Tuple containing the name and buffer

Example:

>>> # xdoctest: +SKIP("undefined vars")
>>> for name, buf in self.named_buffers():
...     if name in ['running_var']:
...         print(buf.size())

named_children() Iterator[Tuple[str, torch.nn.modules.module.Module]]

Returns an iterator over immediate children modules, yielding both the name of the module as well as the module itself.

Yields:

(str, Module): Tuple containing a name and child module

Example:

>>> # xdoctest: +SKIP("undefined vars")
>>> for name, module in model.named_children():
...     if name in ['conv4', 'conv5']:
...         print(module)

named_modules(memo: Optional[Set[torch.nn.modules.module.Module]] = None, prefix: str = '', remove_duplicate: bool = True)

Returns an iterator over all modules in the network, yielding both the name of the module as well as the module itself.

Args:

memo: a memo to store the set of modules already added to the result

prefix: a prefix that will be added to the name of the module

remove_duplicate: whether to remove the duplicated module instances in the result or not

Yields:

(str, Module): Tuple of name and module

Note:

Duplicate modules are returned only once. In the following example, l will be returned only once.

Example:

>>> from torch import nn
>>> l = nn.Linear(2, 2)
>>> net = nn.Sequential(l, l)
>>> for idx, m in enumerate(net.named_modules()):
...     print(idx, '->', m)
0 -> ('', Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=2, out_features=2, bias=True)
))
1 -> ('0', Linear(in_features=2, out_features=2, bias=True))

named_parameters(prefix: str = '', recurse: bool = True) Iterator[Tuple[str, torch.nn.parameter.Parameter]]

Returns an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

Args:

prefix (str): prefix to prepend to all parameter names.

recurse (bool): if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.

Yields:

(str, Parameter): Tuple containing the name and parameter

Example:

>>> # xdoctest: +SKIP("undefined vars")
>>> for name, param in self.named_parameters():
...     if name in ['bias']:
...         print(param.size())

parameters(recurse: bool = True) Iterator[torch.nn.parameter.Parameter]

Returns an iterator over module parameters.

This is typically passed to an optimizer.

Args:

recurse (bool): if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.

Yields:

Parameter: module parameter

Example:

>>> # xdoctest: +SKIP("undefined vars")
>>> for param in model.parameters():
...     print(type(param), param.size())
<class 'torch.nn.parameter.Parameter'> torch.Size([20])
<class 'torch.nn.parameter.Parameter'> torch.Size([20, 1, 5, 5])

register_backward_hook(hook: Callable[[torch.nn.modules.module.Module, Union[Tuple[torch.Tensor, ...], torch.Tensor], Union[Tuple[torch.Tensor, ...], torch.Tensor]], Union[None, torch.Tensor]]) torch.utils.hooks.RemovableHandle

Registers a backward hook on the module.

This function is deprecated in favor of register_full_backward_hook() and the behavior of this function will change in future versions.

Returns:

torch.utils.hooks.RemovableHandle: a handle that can be used to remove the added hook by calling handle.remove()

register_buffer(name: str, tensor: Optional[torch.Tensor], persistent: bool = True) None

Adds a buffer to the module.

This is typically used to register a buffer that should not be considered a model parameter. For example, BatchNorm’s running_mean is not a parameter, but is part of the module’s state. Buffers, by default, are persistent and will be saved alongside parameters. This behavior can be changed by setting persistent to False. The only difference between a persistent buffer and a non-persistent buffer is that the latter will not be a part of this module’s state_dict.

Buffers can be accessed as attributes using given names.

Args:

name (str): name of the buffer. The buffer can be accessed from this module using the given name

tensor (Tensor or None): buffer to be registered. If None, then operations that run on buffers, such as cuda, are ignored. If None, the buffer is not included in the module’s state_dict.

persistent (bool): whether the buffer is part of this module’s state_dict.

Example:

>>> # xdoctest: +SKIP("undefined vars")
>>> self.register_buffer('running_mean', torch.zeros(num_features))

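The difference between persistent and non-persistent buffers is easiest to see in state_dict(). The Normaliser module below is a hypothetical example: both tensors are registered as buffers, neither appears in parameters(), and only the persistent one is saved.

```python
import torch
from torch import nn


class Normaliser(nn.Module):
    """Hypothetical module with one persistent and one non-persistent buffer."""

    def __init__(self, num_features: int) -> None:
        super().__init__()
        # Persistent buffer: saved in state_dict, but not a trainable parameter.
        self.register_buffer("running_mean", torch.zeros(num_features))
        # Non-persistent buffer: follows .to()/.cuda(), but is left out of state_dict.
        self.register_buffer("scratch", torch.zeros(num_features), persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x - self.running_mean


norm = Normaliser(3)
print(list(norm.state_dict().keys()))  # ['running_mean']
print(len(list(norm.parameters())))  # 0
```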
register_forward_hook(hook: Callable[[...], None]) torch.utils.hooks.RemovableHandle

Registers a forward hook on the module.

The hook will be called every time after forward() has computed an output. It should have the following signature:

hook(module, input, output) -> None or modified output

The input contains only the positional arguments given to the module. Keyword arguments won’t be passed to the hooks, only to forward(). The hook can modify the output. It can modify the input in place, but this has no effect on forward, since the hook is called after forward() has run.

Returns:

torch.utils.hooks.RemovableHandle: a handle that can be used to remove the added hook by calling handle.remove()
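A minimal sketch of a forward hook that records the output shape and replaces the output, followed by removing the hook through the returned handle. The layer size and the hook name record_and_zero are arbitrary choices for illustration.

```python
import torch
from torch import nn

layer = nn.Linear(2, 2)
seen_shapes = []


def record_and_zero(module, inputs, output):
    # `inputs` holds only the positional arguments passed to forward().
    seen_shapes.append(tuple(output.shape))
    # Returning a tensor replaces the module's output.
    return torch.zeros_like(output)


handle = layer.register_forward_hook(record_and_zero)
out = layer(torch.ones(1, 2))
print(out.sum().item())  # 0.0
handle.remove()  # the hook no longer runs after this
layer(torch.ones(1, 2))
print(seen_shapes)  # [(1, 2)]
```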

register_forward_pre_hook(hook: Callable[[...], None]) torch.utils.hooks.RemovableHandle

Registers a forward pre-hook on the module.

The hook will be called every time before forward() is invoked. It should have the following signature:

hook(module, input) -> None or modified input

The input contains only the positional arguments given to the module. Keyword arguments won’t be passed to the hooks, only to forward(). The hook can modify the input, and may return either a tuple or a single modified value. A single returned value is wrapped into a tuple (unless that value is already a tuple).

Returns:

torch.utils.hooks.RemovableHandle: a handle that can be used to remove the added hook by calling handle.remove()
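Similarly, a forward pre-hook can rewrite the positional input before forward() runs. In this illustrative sketch, nn.Identity makes the effect visible, and the hook returns a single tensor, which PyTorch wraps into a tuple as described above.

```python
import torch
from torch import nn

identity = nn.Identity()


def double_input(module, inputs):
    # `inputs` is a tuple of the positional arguments.
    (x,) = inputs
    # A single (non-tuple) return value is wrapped into a tuple by PyTorch.
    return x * 2


handle = identity.register_forward_pre_hook(double_input)
doubled = identity(torch.tensor([1.0, 2.0])).tolist()
handle.remove()
plain = identity(torch.tensor([1.0, 2.0])).tolist()
print(doubled, plain)  # [2.0, 4.0] [1.0, 2.0]
```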

register_full_backward_hook(hook: Callable[[torch.nn.modules.module.Module, Union[Tuple[torch.Tensor, ...], torch.Tensor], Union[Tuple[torch.Tensor, ...], torch.Tensor]], Union[None, torch.Tensor]]) torch.utils.hooks.RemovableHandle

Registers a backward hook on the module.

The hook will be called every time the gradients with respect to a module are computed, i.e. the hook will execute if and only if the gradients with respect to module outputs are computed. The hook should have the following signature:

hook(module, grad_input, grad_output) -> tuple(Tensor) or None

The grad_input and grad_output are tuples that contain the gradients with respect to the inputs and outputs respectively. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the input that will be used in place of grad_input in subsequent computations. grad_input will only correspond to the inputs given as positional arguments and all kwarg arguments are ignored. Entries in grad_input and grad_output will be None for all non-Tensor arguments.

For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module’s forward function.

Warning

Modifying inputs or outputs inplace is not allowed when using backward hooks and will raise an error.

Returns:

torch.utils.hooks.RemovableHandle: a handle that can be used to remove the added hook by calling handle.remove()

register_load_state_dict_post_hook(hook)

Registers a post hook to be run after module’s load_state_dict is called.

It should have the following signature:

hook(module, incompatible_keys) -> None

The module argument is the current module that this hook is registered on, and the incompatible_keys argument is a NamedTuple consisting of attributes missing_keys and unexpected_keys. missing_keys is a list of str containing the missing keys and unexpected_keys is a list of str containing the unexpected keys.

The given incompatible_keys can be modified inplace if needed.

Note that the checks performed when calling load_state_dict() with strict=True are affected by modifications the hook makes to missing_keys or unexpected_keys, as expected. Additions to either set of keys will result in an error being thrown when strict=True, and clearing out both missing and unexpected keys will avoid an error.

Returns:

torch.utils.hooks.RemovableHandle: a handle that can be used to remove the added hook by calling handle.remove()

register_module(name: str, module: Optional[torch.nn.modules.module.Module]) None

Alias for add_module().

register_parameter(name: str, param: Optional[torch.nn.parameter.Parameter]) None

Adds a parameter to the module.

The parameter can be accessed as an attribute using given name.

Args:

name (str): name of the parameter. The parameter can be accessed from this module using the given name

param (Parameter or None): parameter to be added to the module. If None, then operations that run on parameters, such as cuda, are ignored. If None, the parameter is not included in the module’s state_dict.

requires_grad_(requires_grad: bool = True) torch.nn.modules.module.T

Change if autograd should record operations on parameters in this module.

This method sets the parameters’ requires_grad attributes in-place.

This method is helpful for freezing part of the module for finetuning or training parts of a model individually (e.g., GAN training).

See the PyTorch notes on locally disabling gradient computation for a comparison between .requires_grad_() and several similar mechanisms that may be confused with it.

Args:

requires_grad (bool): whether autograd should record operations on parameters in this module. Default: True.

Returns:

Module: self
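A common use of requires_grad_ is freezing part of a model. The sketch below, with arbitrary layer sizes, freezes the first layer of a two-layer nn.Sequential so that only the second layer's parameters remain trainable.

```python
from torch import nn

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
# Freeze the first layer, e.g. to fine-tune only the final layer.
model[0].requires_grad_(False)
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(trainable)  # ['1.weight', '1.bias']
```

Only the still-trainable parameters would then typically be passed to an optimizer.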

save(checkpoint_path: Union[str, os.PathLike[str]]) None

Create a lambeq Checkpoint and save to a path.

Example:

>>> from lambeq import PytorchModel
>>> model = PytorchModel()
>>> model.save('my_checkpoint.lt')

Parameters
checkpoint_path : str or PathLike

Path that points to the checkpoint file.

set_extra_state(state: Any)

This function is called from load_state_dict() to handle any extra state found within the state_dict. Implement this function and a corresponding get_extra_state() for your module if you need to store extra state within its state_dict.

Args:

state (dict): Extra state from the state_dict

share_memory() torch.nn.modules.module.T

See torch.Tensor.share_memory_()

state_dict(*args, destination=None, prefix='', keep_vars=False)

Returns a dictionary containing references to the whole state of the module.

Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included.

Note

The returned object is a shallow copy. It contains references to the module’s parameters and buffers.

Warning

Currently state_dict() also accepts positional arguments for destination, prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases.

Warning

Please avoid the use of argument destination as it is not designed for end-users.

Args:
destination (dict, optional): If provided, the state of module will

be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None.

prefix (str, optional): a prefix added to parameter and buffer

names to compose the keys in state_dict. Default: ''.

keep_vars (bool, optional): by default the Tensor s

returned in the state dict are detached from autograd. If it’s set to True, detaching will not be performed. Default: False.

Returns:

dict: a dictionary containing the whole state of the module

Example:

>>> # xdoctest: +SKIP("undefined vars")
>>> module.state_dict().keys()
['bias', 'weight']
symbols: list[Symbol]
to(*args, **kwargs)

Moves and/or casts the parameters and buffers.

This can be called as

to(device=None, dtype=None, non_blocking=False)
to(dtype, non_blocking=False)
to(tensor, non_blocking=False)
to(memory_format=torch.channels_last)

Its signature is similar to torch.Tensor.to(), but only accepts floating point or complex dtypes. In addition, this method will only cast the floating point or complex parameters and buffers to dtype (if given). The integral parameters and buffers will be moved to device, if that is given, but with dtypes unchanged. When non_blocking is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices.

See below for examples.

Note

This method modifies the module in-place.

Args:
device (torch.device): the desired device of the parameters and buffers in this module

dtype (torch.dtype): the desired floating point or complex dtype of the parameters and buffers in this module

tensor (torch.Tensor): Tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module

memory_format (torch.memory_format): the desired memory format for 4D parameters and buffers in this module (keyword only argument)

Returns:

Module: self

Examples:

>>> # xdoctest: +IGNORE_WANT("non-deterministic")
>>> import torch
>>> from torch import nn
>>> linear = nn.Linear(2, 2)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
        [-0.5113, -0.2325]])
>>> linear.to(torch.double)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
        [-0.5113, -0.2325]], dtype=torch.float64)
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1)
>>> gpu1 = torch.device("cuda:1")
>>> linear.to(gpu1, dtype=torch.half, non_blocking=True)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
        [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1')
>>> cpu = torch.device("cpu")
>>> linear.to(cpu)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
        [-0.5112, -0.2324]], dtype=torch.float16)

>>> linear = nn.Linear(2, 2, bias=False).to(torch.cdouble)
>>> linear.weight
Parameter containing:
tensor([[ 0.3741+0.j,  0.2382+0.j],
        [ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128)
>>> linear(torch.ones(3, 2, dtype=torch.cdouble))
tensor([[0.6122+0.j, 0.1150+0.j],
        [0.6122+0.j, 0.1150+0.j],
        [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128)
to_empty(*, device: Union[str, torch.device]) torch.nn.modules.module.T

Moves the parameters and buffers to the specified device without copying storage.

Args:
device (torch.device): The desired device of the parameters and buffers in this module.

Returns:

Module: self

train(mode: bool = True) torch.nn.modules.module.T

Sets the module in training mode.

This only has an effect on certain modules. See the documentation of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. Dropout, BatchNorm, etc.

Args:
mode (bool): whether to set training mode (True) or evaluation mode (False). Default: True.

Returns:

Module: self

training: bool
type(dst_type: Union[torch.dtype, str]) torch.nn.modules.module.T

Casts all parameters and buffers to dst_type.

Note

This method modifies the module in-place.

Args:

dst_type (type or string): the desired type

Returns:

Module: self

weights: torch.nn.ParameterList
xpu(device: Optional[Union[int, torch.device]] = None) torch.nn.modules.module.T

Moves all model parameters and buffers to the XPU.

This also makes associated parameters and buffers different objects, so it should be called before constructing the optimizer if the module will live on the XPU while being optimized.

Note

This method modifies the module in-place.

Arguments:
device (int, optional): if specified, all parameters will be copied to that device

Returns:

Module: self

zero_grad(set_to_none: bool = False) None

Sets gradients of all model parameters to zero. See similar function under torch.optim.Optimizer for more context.

Args:
set_to_none (bool): instead of setting to zero, set the grads to None.

See torch.optim.Optimizer.zero_grad() for details.

class lambeq.training.PytorchTrainer(model: PytorchModel, loss_function: Callable[..., torch.Tensor], epochs: int, optimizer: type[torch.optim.Optimizer] = <class 'torch.optim.adamw.AdamW'>, learning_rate: float = 0.001, device: int = -1, *, optimizer_args: Optional[dict[str, Any]] = None, evaluate_functions: Optional[Mapping[str, _EvalFuncT]] = None, evaluate_on_train: bool = True, use_tensorboard: bool = False, log_dir: Optional[_StrPathT] = None, from_checkpoint: bool = False, verbose: str = 'text', seed: Optional[int] = None)[source]

Bases: lambeq.training.trainer.Trainer

A PyTorch trainer for the classical pipeline.

__init__(model: PytorchModel, loss_function: Callable[..., torch.Tensor], epochs: int, optimizer: type[torch.optim.Optimizer] = <class 'torch.optim.adamw.AdamW'>, learning_rate: float = 0.001, device: int = -1, *, optimizer_args: Optional[dict[str, Any]] = None, evaluate_functions: Optional[Mapping[str, _EvalFuncT]] = None, evaluate_on_train: bool = True, use_tensorboard: bool = False, log_dir: Optional[_StrPathT] = None, from_checkpoint: bool = False, verbose: str = 'text', seed: Optional[int] = None) None[source]

Initialise a Trainer instance using the PyTorch backend.

Parameters
model : PytorchModel

A lambeq Model using PyTorch for tensor computation.

loss_function : callable

A PyTorch loss function from torch.nn.

epochs : int

Number of training epochs.

optimizer : torch.optim.Optimizer, default: torch.optim.AdamW

A PyTorch optimizer from torch.optim.

learning_rate : float, default: 1e-3

The learning rate provided to the optimizer for training.

device : int, default: -1

CUDA device ID used for tensor operation speed-up. A negative value uses the CPU.

optimizer_args : dict of str to Any, optional

Any extra arguments to pass to the optimizer.

evaluate_functions : mapping of str to callable, optional

Mapping of evaluation metric functions from their names, structured as {“metric”: func}. Each function takes the prediction “y_hat” and the label “y” as input. The validation step calls “func(y_hat, y)”.

evaluate_on_train : bool, default: True

Evaluate the metrics on the train dataset.

use_tensorboard : bool, default: False

Use Tensorboard for visualisation of the training logs.

log_dir : str or PathLike, optional

Location of model checkpoints (and tensorboard log). Default is runs/CURRENT_DATETIME_HOSTNAME.

from_checkpoint : bool, default: False

Start training from the checkpoint saved in log_dir.

verbose : str, default: ‘text’

See VerbosityLevel for options.

seed : int, optional

Random seed.
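As an illustration of the evaluate_functions mapping, the sketch below defines a hypothetical accuracy metric with the documented func(y_hat, y) signature. It assumes predictions and one-hot labels are nested Python lists; in practice the trainer passes tensors, so this only shows the shape of the mapping and how each entry is called.

```python
# Hypothetical metric for the evaluate_functions mapping. Predictions and
# one-hot labels are plain nested lists here; a real trainer passes tensors.

def accuracy(y_hat, y):
    """Fraction of rows where the arg-max of y_hat matches the arg-max of y."""
    def argmax(row):
        return max(range(len(row)), key=lambda i: row[i])
    correct = sum(argmax(p) == argmax(t) for p, t in zip(y_hat, y))
    return correct / len(y)


evaluate_functions = {'acc': accuracy}

# Roughly what a validation step does with the mapping: call func(y_hat, y)
y_hat = [[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]]
y = [[1, 0], [0, 1], [0, 1]]
results = {name: func(y_hat, y) for name, func in evaluate_functions.items()}
print(results)  # {'acc': 0.6666666666666666}
```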

fit(train_dataset: lambeq.training.dataset.Dataset, val_dataset: Optional[lambeq.training.dataset.Dataset] = None, evaluation_step: int = 1, logging_step: int = 1) None

Fit the model on the training data and, optionally, evaluate it on the validation data.

Parameters
train_dataset : Dataset

Dataset used for training.

val_dataset : Dataset, optional

Validation dataset.

evaluation_step : int, default: 1

Sets the intervals at which the metrics are evaluated on the validation dataset.

logging_step : int, default: 1

Sets the intervals at which the training statistics are printed if verbose = ‘text’ (otherwise ignored).
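To make the two interval parameters concrete, here is a hypothetical loop in which evaluation_step and logging_step gate validation and logging with a modulo check. Counting the interval in epochs is an assumption for illustration; the trainer's own loop may count steps differently.

```python
# Hypothetical sketch: evaluation_step and logging_step as modulo intervals.
# Counting in epochs is an assumption; the real trainer may count otherwise.

def run_schedule(epochs, evaluation_step=1, logging_step=1):
    events = []
    for epoch in range(1, epochs + 1):
        # ... one pass over the training batches would happen here ...
        if epoch % logging_step == 0:
            events.append(('log', epoch))
        if epoch % evaluation_step == 0:
            events.append(('eval', epoch))
    return events


print(run_schedule(4, evaluation_step=2, logging_step=1))
# [('log', 1), ('log', 2), ('eval', 2), ('log', 3), ('log', 4), ('eval', 4)]
```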

load_training_checkpoint(log_dir: Union[str, os.PathLike[str]]) lambeq.training.checkpoint.Checkpoint

Load model from a checkpoint.

Parameters
log_dir : str or PathLike

The directory containing the model.lt checkpoint file.

Returns
Checkpoint

Checkpoint containing the model weights, symbols and the training history.

Raises
FileNotFoundError

If the file does not exist.

model: PytorchModel
save_checkpoint(save_dict: Mapping[str, Any], log_dir: _StrPathT) None

Save checkpoint.

Parameters
save_dict : mapping of str to any

Mapping containing the checkpoint information.

log_dir : str or PathLike

The directory in which to store the model.lt checkpoint file.
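The save/load pair can be pictured as a round trip through a model.lt file inside log_dir. The functions below are hypothetical stand-ins that share the method names; using pickle as the on-disk format is an assumption made for this sketch, not a statement about how lambeq serialises checkpoints.

```python
import pickle
import tempfile
from pathlib import Path

CHECKPOINT_NAME = 'model.lt'  # file name mentioned in the docs


def save_checkpoint(save_dict, log_dir):
    """Hypothetical stand-in: write save_dict to <log_dir>/model.lt (pickle assumed)."""
    path = Path(log_dir) / CHECKPOINT_NAME
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, 'wb') as f:
        pickle.dump(dict(save_dict), f)


def load_training_checkpoint(log_dir):
    """Hypothetical stand-in: read <log_dir>/model.lt or raise FileNotFoundError."""
    path = Path(log_dir) / CHECKPOINT_NAME
    if not path.exists():
        raise FileNotFoundError(f'Checkpoint not found: {path}')
    with open(path, 'rb') as f:
        return pickle.load(f)


with tempfile.TemporaryDirectory() as log_dir:
    save_checkpoint({'epoch': 3, 'model_weights': [0.1, 0.2]}, log_dir)
    restored = load_training_checkpoint(log_dir)
    print(restored['epoch'])  # 3
```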

training_step(batch: tuple[list[Any], torch.Tensor]) tuple[torch.Tensor, float][source]

Perform a training step.

Parameters
batch : tuple of list and torch.Tensor

Current batch.

Returns
Tuple of torch.Tensor and float

The model predictions and the calculated loss.

validation_step(batch: tuple[list[Any], torch.Tensor]) tuple[torch.Tensor, float][source]

Perform a validation step.

Parameters
batch : tuple of list and torch.Tensor

Current batch.

Returns
Tuple of torch.Tensor and float

The model predictions and the calculated loss.

class lambeq.training.QuantumModel[source]

Bases: lambeq.training.model.Model

Quantum Model base class.

Attributes
symbols : list of symbols

A sorted list of all Symbols occurring in the data.

weights : array

A data structure containing the numeric values of the model parameters.

SMOOTHING : float

A smoothing constant.
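The attribute list only says that SMOOTHING is a smoothing constant. One common way such a constant is used, shown here as an assumption rather than a description of the model's internals, is to add it to the magnitudes of raw outputs before normalising them into a probability vector, so that no entry is exactly zero. The value 1e-9 is illustrative only.

```python
SMOOTHING = 1e-9  # illustrative value, not taken from the library


def normalise(raw_outputs):
    """Turn raw (possibly signed) outputs into a probability vector.

    Adding SMOOTHING keeps every entry strictly positive, so the sum is
    never zero and log-based losses computed afterwards stay finite.
    """
    smoothed = [abs(x) + SMOOTHING for x in raw_outputs]
    total = sum(smoothed)
    return [x / total for x in smoothed]


print(normalise([0.0, 0.0]))  # [0.5, 0.5]
```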

__call__(*args: Any, **kwargs: Any) Any[source]

Call self as a function.

__init__() None[source]

Initialise a QuantumModel.

abstract forward(x: list[Diagram]) Any[source]

Compute the forward pass of the model using get_diagram_output.

classmethod from_checkpoint(checkpoint_path: Union[str, os.PathLike[str]], **kwargs: Any) lambeq.training.model.Model

Load the weights and symbols from a training checkpoint.

Parameters
checkpoint_path : str or PathLike

Path that points to the checkpoint file.

Other Parameters
backend_config : dict

Dictionary containing the backend configuration for the TketModel. Must include the fields ‘backend’, ‘compilation’ and ‘shots’.

classmethod from_diagrams(diagrams: list[Diagram], **kwargs: Any) Model

Build model from a list of Diagrams.

Parameters
diagrams : list of Diagram

The tensor or circuit diagrams to be evaluated.

Other Parameters
backend_config : dict

Dictionary containing the backend configuration for the TketModel. Must include the fields ‘backend’, ‘compilation’ and ‘shots’.

use_jit : bool, default: False

Whether to use JAX’s Just-In-Time compilation in NumpyModel.

abstract get_diagram_output(diagrams: list[Diagram]) Union[jnp.ndarray, np.ndarray][source]

Return the diagram prediction.

Parameters
diagrams : list of Diagram

The Circuits to be evaluated.

initialise_weights() None[source]

Initialise the weights of the model.

Raises
ValueError

If model.symbols are not initialised.
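The relationship between the sorted symbols list and initialise_weights can be sketched in plain Python. Here each diagram is modelled as a set of symbol names and weights are drawn uniformly from [0, 1); both the representation and the distribution are assumptions for illustration, and the symbol names are hypothetical.

```python
import random


def collect_symbols(diagrams):
    """Gather the sorted, de-duplicated symbols used across all diagrams.

    Each diagram is modelled as a plain set of symbol names, a hypothetical
    simplification of how real diagrams expose their free symbols.
    """
    return sorted(set().union(*diagrams), key=str)


def initialise_weights(symbols, seed=None):
    """Draw one uniform [0, 1) weight per symbol (distribution assumed)."""
    if not symbols:
        raise ValueError('Symbols not initialised.')
    rng = random.Random(seed)
    return [rng.random() for _ in symbols]


diagrams = [{'alice__n_0', 'loves__s_0'}, {'bob__n_0', 'loves__s_0'}]
symbols = collect_symbols(diagrams)
weights = initialise_weights(symbols, seed=0)
print(symbols)  # ['alice__n_0', 'bob__n_0', 'loves__s_0']
```

Keeping symbols sorted gives each parameter a stable index, so weights[i] always belongs to symbols[i] across runs and checkpoints.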

load(checkpoint_path: Union[str, os.PathLike[str]]) None

Load model data from a path pointing to a lambeq checkpoint.

Checkpoints that are created by a lambeq Trainer usually have the extension .lt.

Parameters
checkpoint_path : str or PathLike

Path that points to the checkpoint file.

save(checkpoint_path: Union[str, os.PathLike[str]]) None

Create a lambeq Checkpoint and save to a path.

Example:

>>> from lambeq import PytorchModel
>>> model = PytorchModel()
>>> model.save('my_checkpoint.lt')

Parameters
checkpoint_path : str or PathLike

Path that points to the checkpoint file.

weights: np.ndarray
class lambeq.training.QuantumTrainer(model: QuantumModel, loss_function: Callable[..., float], epochs: int, optimizer: type[Optimizer], optim_hyperparams: dict[str, float], *, optimizer_args: Optional[dict[str, Any]] = None, evaluate_functions: Optional[Mapping[str, _EvalFuncT]] = None, evaluate_on_train: bool = True, use_tensorboard: bool = False, log_dir: Optional[_StrPathT] = None, from_checkpoint: bool = False, verbose: str = 'text', seed: Optional[int] = None)[source]

Bases: lambeq.training.trainer.Trainer

A Trainer for the quantum pipeline.

__init__(model: QuantumModel, loss_function: Callable[..., float], epochs: int, optimizer: type[Optimizer], optim_hyperparams: dict[str, float], *, optimizer_args: Optional[dict[str, Any]] = None, evaluate_functions: Optional[Mapping[str, _EvalFuncT]] = None, evaluate_on_train: bool = True, use_tensorboard: bool = False, log_dir: Optional[_StrPathT] = None, from_checkpoint: bool = False, verbose: str = 'text', seed: Optional[int] = None) None[source]

Initialise a Trainer using a quantum backend.

Parameters
model : QuantumModel

A lambeq Model.

loss_function : callable

A loss function.

epochs : int

Number of training epochs.

optimizer : Optimizer

An optimizer class derived from lambeq.training.Optimizer.

optim_hyperparams : dict of str to float

The hyperparameters to be used by the optimizer.

optimizer_args : dict of str to Any, optional

Any extra arguments to pass to the optimizer.

evaluate_functions : mapping of str to callable, optional

Mapping of evaluation metric functions from their names, structured as {“metric”: func}. Each function takes the prediction “y_hat” and the label “y” as input. The validation step calls “func(y_hat, y)”.

evaluate_on_train : bool, default: True

Evaluate the metrics on the train dataset.

use_tensorboard : bool, default: False

Use Tensorboard for visualisation of the training logs.

log_dir : str or PathLike, optional

Location of model checkpoints (and tensorboard log). Default is runs/CURRENT_DATETIME_HOSTNAME.

from_checkpoint : bool, default: False

Start training from the checkpoint saved in log_dir.

verbose : str, default: ‘text’

See VerbosityLevel for options.

seed : int, optional

Random seed.
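The default log_dir is described only as runs/CURRENT_DATETIME_HOSTNAME. The sketch below builds such a path following the naming convention used by PyTorch's TensorBoard SummaryWriter (abbreviated month, day, time, then host name); that lambeq uses exactly this format is an assumption.

```python
import os
import socket
from datetime import datetime


def default_log_dir(now=None, hostname=None):
    """Build a runs/<datetime>_<hostname> path (format assumed, see lead-in)."""
    if now is None:
        now = datetime.now()
    if hostname is None:
        hostname = socket.gethostname()
    current_time = now.strftime('%b%d_%H-%M-%S')
    return os.path.join('runs', f'{current_time}_{hostname}')


print(default_log_dir(datetime(2023, 1, 5, 14, 30, 0), 'myhost'))
# runs/Jan05_14-30-00_myhost on POSIX systems
```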

fit(train_dataset: lambeq.training.dataset.Dataset, val_dataset: Optional[lambeq.training.dataset.Dataset] = None, evaluation_step: int = 1, logging_step: int = 1) None[source]

Fit the model on the training data and, optionally, evaluate it on the validation data.

Parameters
train_dataset : Dataset

Dataset used for training.

val_dataset : Dataset, optional

Validation dataset.

evaluation_step : int, default: 1

Sets the intervals at which the metrics are evaluated on the validation dataset.

logging_step : int, default: 1

Sets the intervals at which the training statistics are printed if verbose = ‘text’ (otherwise ignored).

load_training_checkpoint(log_dir: Union[str, os.PathLike[str]]) lambeq.training.checkpoint.Checkpoint

Load model from a checkpoint.

Parameters
log_dir : str or PathLike

The directory containing the model.lt checkpoint file.

Returns
Checkpoint

Checkpoint containing the model weights, symbols and the training history.

Raises
FileNotFoundError

If the file does not exist.

model: QuantumModel
save_checkpoint(save_dict: Mapping[str, Any], log_dir: _StrPathT) None

Save checkpoint.

Parameters
save_dict : mapping of str to any

Mapping containing the checkpoint information.

log_dir : str or PathLike

The directory in which to store the model.lt checkpoint file.

training_step(batch: tuple[list[Any], np.ndarray]) tuple[np.ndarray, float][source]

Perform a training step.

Parameters
batch : tuple of list and np.ndarray

Current batch.

Returns
Tuple of np.ndarray and float

The model predictions and the calculated loss.

validation_step(batch: tuple[list[Any], np.ndarray]) tuple[np.ndarray, float][source]

Perform a validation step.

Parameters
batch : tuple of list and np.ndarray

Current batch.

Returns
Tuple of np.ndarray and float

The model predictions and the calculated loss.

class lambeq.training.SPSAOptimizer(model: QuantumModel, hyperparams: dict[str, float], loss_fn: Callable[[Any, Any], float], bounds: Optional[ArrayLike] = None)[source]

Bases: lambeq.training.optimizer.Optimizer

An Optimizer using SPSA.

SPSA = Simultaneous Perturbation Stochastic Approximation. See https://ieeexplore.ieee.org/document/705889 for details.

__init__(model: QuantumModel, hyperparams: dict[str, float], loss_fn: Callable[[Any, Any], float], bounds: Optional[ArrayLike] = None) None[source]

Initialise the SPSA optimizer.

The hyperparameters must contain the following key-value pairs:

hyperparams = {
    'a': ...,  # a learning rate parameter (float)
    'c': ...,  # the parameter shift scaling factor (float)
    'A': ...,  # a stability constant (float)
}

A good value for ‘A’ is approximately 0.01 * (number of training steps).

Parameters
model : QuantumModel

A lambeq quantum model.

hyperparams : dict of str to float

A dictionary containing the model’s hyperparameters.

loss_fn : Callable

A loss function of the form loss(prediction, labels).

bounds : ArrayLike, optional

The range of each of the model parameters.

Raises
ValueError

If the hyperparameters are not set correctly, or if the length of bounds does not match the number of the model parameters.
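For orientation, the hyperparameters 'a', 'c' and 'A' match the gain sequences of Spall's standard SPSA formulation, a_k = a / (k + 1 + A)^alpha and c_k = c / (k + 1)^gamma. The sketch below computes them; the exponents alpha = 0.602 and gamma = 0.101 are Spall's commonly recommended values and are assumed here rather than read from this class. The hyperparameter values in the usage lines are arbitrary examples.

```python
def spsa_gains(k, a, c, A, alpha=0.602, gamma=0.101):
    """Return (a_k, c_k) for iteration k >= 0 using Spall's gain sequences.

    a_k scales the parameter update; c_k scales the random perturbation.
    A delays the decay of a_k so that early steps are not too large.
    """
    a_k = a / (k + 1 + A) ** alpha
    c_k = c / (k + 1) ** gamma
    return a_k, c_k


n_steps = 1000
hyperparams = {'a': 0.05, 'c': 0.06, 'A': 0.01 * n_steps}  # example values only
for k in (0, 10, 100):
    print(k, spsa_gains(k, **hyperparams))
```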

backward(batch: tuple[Iterable[Any], np.ndarray]) float[source]

Calculate the gradients of the loss function.

The gradients are calculated with respect to the model parameters.

Parameters
batch : tuple of Iterable and numpy.ndarray

Current batch. Contains an Iterable of diagrams in index 0, and the targets in index 1.

Returns
float

The calculated loss.
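The core of an SPSA gradient estimate is that every parameter is perturbed at once along a random ±1 direction, so only two loss evaluations are needed regardless of the number of parameters. The sketch below shows that estimator on plain Python lists; unlike the documented backward, which returns a loss, it returns only the gradient estimate, and the function name is hypothetical.

```python
import random


def spsa_gradient(loss, theta, c_k, rng):
    """Estimate the gradient of loss at theta with two loss evaluations.

    All coordinates are perturbed simultaneously by +/- c_k along a random
    Rademacher direction delta (each entry is +1 or -1).
    """
    delta = [rng.choice((-1.0, 1.0)) for _ in theta]
    plus = [t + c_k * d for t, d in zip(theta, delta)]
    minus = [t - c_k * d for t, d in zip(theta, delta)]
    diff = loss(plus) - loss(minus)
    # One shared finite difference, projected back onto each coordinate.
    return [diff / (2 * c_k * d) for d in delta]


rng = random.Random(0)
print(spsa_gradient(lambda th: th[0] ** 2, [2.0], 0.1, rng))  # close to [4.0]
```

In one dimension the estimate equals the central finite difference; in higher dimensions it is only correct in expectation over the random directions, which is why SPSA relies on many iterations.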

load_state_dict(state_dict: Mapping[str, Any]) None[source]

Load state of the optimizer from the state dictionary.

Parameters
state_dict : dict

A dictionary containing a snapshot of the optimizer state.

model: QuantumModel
state_dict() dict[str, Any][source]

Return optimizer states as dictionary.

Returns
dict

A dictionary containing the current state of the optimizer.

step() None[source]

Perform an optimisation step.

update_hyper_params() None[source]

Update the hyperparameters of the SPSA algorithm.

zero_grad() None

Reset the gradients to zero.
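An SPSA step moves the parameters against the gradient estimate, scaled by the current gain a_k, and, when bounds are given, keeps each parameter inside its range. The sketch below assumes bounds is a sequence of (low, high) pairs, one per parameter, and mirrors the documented ValueError for a length mismatch; the helper name is hypothetical.

```python
def spsa_step(theta, grad, a_k, bounds=None):
    """Move theta against the gradient estimate and clip to optional bounds.

    bounds is assumed to be a sequence of (low, high) pairs, one per parameter.
    """
    new_theta = [t - a_k * g for t, g in zip(theta, grad)]
    if bounds is not None:
        if len(bounds) != len(theta):
            raise ValueError('Length of bounds does not match the number of parameters.')
        new_theta = [min(max(t, lo), hi) for t, (lo, hi) in zip(new_theta, bounds)]
    return new_theta


# The second parameter would reach 1.1 but is clipped to its upper bound 1.0.
print(spsa_step([0.5, 0.9], [-1.0, -2.0], 0.1, bounds=[(0.0, 1.0), (0.0, 1.0)]))
```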

class lambeq.training.TketModel(backend_config: dict[str, Any])[source]

Bases: lambeq.training.quantum_model.QuantumModel

Model based on tket.

This can run either shot-based simulations of a quantum pipeline or experiments run on quantum hardware using tket.

__call__(*args: Any, **kwargs: Any) Any

Call self as a function.

__init__(backend_config: dict[str, Any]) None[source]

Initialise TketModel based on the t|ket> backend.

Other Parameters
backend_config : dict

Dictionary containing the backend configuration. Must include the fields backend, compilation and shots.

Raises
KeyError

If backend_config is not provided or has missing fields.
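The documented KeyError behaviour amounts to a check for three required fields. The sketch below reproduces that check on a plain dict. In lambeq's tutorials 'backend' is typically a pytket backend object and 'compilation' its default compilation pass; here both are None placeholders so the example runs without pytket, and the helper name is hypothetical.

```python
REQUIRED_FIELDS = ('backend', 'compilation', 'shots')


def check_backend_config(backend_config):
    """Raise KeyError if backend_config is missing or lacks a required field."""
    if backend_config is None:
        raise KeyError('backend_config must be provided.')
    missing = [f for f in REQUIRED_FIELDS if f not in backend_config]
    if missing:
        raise KeyError(f'Missing backend_config fields: {missing}')
    return backend_config


# Placeholders: a real config would hold a pytket backend and compilation pass.
config = {'backend': None, 'compilation': None, 'shots': 8192}
check_backend_config(config)
```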

forward(x: list[Diagram]) np.ndarray[source]

Perform default forward pass of a lambeq quantum model.

In case of a different datapoint (e.g. a list of tuples) or additional computational steps, please override this method.

Parameters
x : list of Diagram

The Circuits to be evaluated.

Returns
np.ndarray

Array containing the model’s predictions.

classmethod from_checkpoint(checkpoint_path: Union[str, os.PathLike[str]], **kwargs: Any) lambeq.training.model.Model

Load the weights and symbols from a training checkpoint.

Parameters
checkpoint_path : str or PathLike

Path that points to the checkpoint file.

Other Parameters
backend_config : dict

Dictionary containing the backend configuration for the TketModel. Must include the fields ‘backend’, ‘compilation’ and ‘shots’.

classmethod from_diagrams(diagrams: list[Diagram], **kwargs: Any) Model

Build model from a list of Diagrams.

Parameters
diagrams : list of Diagram

The tensor or circuit diagrams to be evaluated.

Other Parameters
backend_config : dict

Dictionary containing the backend configuration for the TketModel. Must include the fields ‘backend’, ‘compilation’ and ‘shots’.

use_jit : bool, default: False

Whether to use JAX’s Just-In-Time compilation in NumpyModel.

get_diagram_output(diagrams: list[Diagram]) np.ndarray[source]

Return the prediction for each diagram using t|ket>.

Parameters
diagrams : list of Diagram

The Circuits to be evaluated.

Returns
np.ndarray

Resulting array.

Raises
ValueError

If model.weights or model.symbols are not initialised.
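With shot-based execution, each diagram's output is estimated from measurement counts. The sketch below shows one way to turn a counts dictionary keyed by bit tuples (the shape pytket's get_counts returns) into a normalised probability list; the big-endian index ordering and the helper name are assumptions for illustration, not a description of how get_diagram_output orders its results.

```python
def counts_to_probabilities(counts, n_qubits):
    """Convert shot counts into a normalised probability list.

    counts maps bit tuples such as (0, 1) to how often they were measured.
    Index i of the result is the bitstring of i read big-endian (assumed).
    """
    shots = sum(counts.values())
    probs = [0.0] * (2 ** n_qubits)
    for bits, n in counts.items():
        index = int(''.join(map(str, bits)), 2)
        probs[index] = n / shots
    return probs


print(counts_to_probabilities({(0,): 300, (1,): 700}, 1))  # [0.3, 0.7]
```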

initialise_weights() None

Initialise the weights of the model.

Raises
ValueError

If model.symbols are not initialised.

load(checkpoint_path: Union[str, os.PathLike[str]]) None

Load model data from a path pointing to a lambeq checkpoint.

Checkpoints that are created by a lambeq Trainer usually have the extension .lt.

Parameters
checkpoint_path : str or PathLike

Path that points to the checkpoint file.

save(checkpoint_path: Union[str, os.PathLike[str]]) None

Create a lambeq Checkpoint and save to a path.

Example:

>>> from lambeq import PytorchModel
>>> model = PytorchModel()
>>> model.save('my_checkpoint.lt')

Parameters
checkpoint_path : str or PathLike

Path that points to the checkpoint file.

symbols: list[Union[Symbol, SymPySymbol]]
weights: np.ndarray
class lambeq.training.Trainer(model: Model, loss_function: Callable[..., Any], epochs: int, evaluate_functions: Optional[Mapping[str, _EvalFuncT]] = None, evaluate_on_train: bool = True, use_tensorboard: bool = False, log_dir: Optional[_StrPathT] = None, from_checkpoint: bool = False, verbose: str = 'text', seed: Optional[int] = None)[source]

Bases: abc.ABC

Base class for a lambeq trainer.

__init__(model: Model, loss_function: Callable[..., Any], epochs: int, evaluate_functions: Optional[Mapping[str, _EvalFuncT]] = None, evaluate_on_train: bool = True, use_tensorboard: bool = False, log_dir: Optional[_StrPathT] = None, from_checkpoint: bool = False, verbose: str = 'text', seed: Optional[int] = None) None[source]

Initialise a lambeq trainer.

Parameters
model : Model

A lambeq Model.

loss_function : callable

A loss function to compare the prediction to the true label.

epochs : int

Number of training epochs.

evaluate_functions : mapping of str to callable, optional

Mapping of evaluation metric functions from their names.

evaluate_on_train : bool, default: True

Evaluate the metrics on the train dataset.

use_tensorboard : bool, default: False

Use Tensorboard for visualisation of the training logs.

log_dir : str or PathLike, optional

Location of model checkpoints (and tensorboard log). Default is runs/CURRENT_DATETIME_HOSTNAME.

from_checkpoint : bool, default: False

Start training from the checkpoint saved in log_dir.

verbose : str, default: ‘text’

See VerbosityLevel for options.

seed : int, optional

Random seed.

fit(train_dataset: lambeq.training.dataset.Dataset, val_dataset: Optional[lambeq.training.dataset.Dataset] = None, evaluation_step: int = 1, logging_step: int = 1) None[source]

Fit the model on the training data and, optionally, evaluate it on the validation data.

Parameters
train_dataset : Dataset

Dataset used for training.

val_dataset : Dataset, optional

Validation dataset.

evaluation_step : int, default: 1

Sets the intervals at which the metrics are evaluated on the validation dataset.

logging_step : int, default: 1

Sets the intervals at which the training statistics are printed if verbose = ‘text’ (otherwise ignored).

load_training_checkpoint(log_dir: Union[str, os.PathLike[str]]) lambeq.training.checkpoint.Checkpoint[source]

Load model from a checkpoint.

Parameters
log_dir : str or PathLike

The directory containing the model.lt checkpoint file.

Returns
Checkpoint

Checkpoint containing the model weights, symbols and the training history.

Raises
FileNotFoundError

If the file does not exist.

save_checkpoint(save_dict: Mapping[str, Any], log_dir: _StrPathT) None[source]

Save checkpoint.

Parameters
save_dict : mapping of str to any

Mapping containing the checkpoint information.

log_dir : str or PathLike

The directory in which to store the model.lt checkpoint file.

abstract training_step(batch: tuple[list[Any], Any]) tuple[Any, float][source]

Perform a training step.

Parameters
batch : tuple of list and any

Current batch.

Returns
Tuple of any and float

The model predictions and the calculated loss.

abstract validation_step(batch: tuple[list[Any], Any]) tuple[Any, float][source]

Perform a validation step.

Parameters
batch : tuple of list and any

Current batch.

Returns
Tuple of any and float

The model predictions and the calculated loss.
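The division of labour in the abstract base class (fit drives the loop; subclasses supply training_step and validation_step, each returning predictions and a loss for a (data, targets) batch) can be illustrated with a self-contained stand-in. MiniTrainer and MeanTrainer below are hypothetical classes, not lambeq's Trainer; the toy model learns a single scalar that predicts the mean of the targets.

```python
from abc import ABC, abstractmethod


class MiniTrainer(ABC):
    """Hypothetical stand-in for the abstract base: fit() drives the two steps."""

    def __init__(self, epochs):
        self.epochs = epochs
        self.train_losses = []

    @abstractmethod
    def training_step(self, batch):
        """Return (predictions, loss) and update the model."""

    @abstractmethod
    def validation_step(self, batch):
        """Return (predictions, loss) without updating the model."""

    def fit(self, train_batches, val_batches=()):
        for _ in range(self.epochs):
            epoch_loss = 0.0
            for batch in train_batches:
                _, loss = self.training_step(batch)
                epoch_loss += loss
            self.train_losses.append(epoch_loss / len(train_batches))
            for batch in val_batches:
                self.validation_step(batch)


class MeanTrainer(MiniTrainer):
    """Toy model: one scalar w trained by gradient descent on mean squared error."""

    def __init__(self, epochs, lr=0.5):
        super().__init__(epochs)
        self.w = 0.0
        self.lr = lr

    def _predict_and_loss(self, batch):
        _, targets = batch
        preds = [self.w for _ in targets]
        loss = sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(targets)
        return preds, loss

    def training_step(self, batch):
        _, targets = batch
        preds, loss = self._predict_and_loss(batch)
        # Gradient of the mean squared error with respect to w.
        grad = sum(2 * (self.w - t) for t in targets) / len(targets)
        self.w -= self.lr * grad
        return preds, loss

    def validation_step(self, batch):
        return self._predict_and_loss(batch)


trainer = MeanTrainer(epochs=3)
trainer.fit([(['x', 'y'], [1.0, 3.0])])
print(trainer.w, trainer.train_losses)  # 2.0 [5.0, 1.0, 1.0]
```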