# Source code for lambeq.training.dataset (reconstructed from a scraped documentation page)

# Copyright 2021-2024 Cambridge Quantum Computing Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A module containing a Dataset class for training lambeq models."""

from __future__ import annotations

from collections.abc import Iterator
from math import ceil
import random
from typing import Any

from lambeq.backend.numerical_backend import get_backend

class Dataset:
    """Dataset class for the training of a lambeq model.

    Data is returned in the format of lambeq's numerical backend, which
    by default is set to NumPy. For example, to access the dataset as
    PyTorch tensors:

    >>> from lambeq.backend import numerical_backend
    >>> dataset = Dataset(['data1'], [[0, 1, 2, 3]])
    >>> with numerical_backend.backend('pytorch'):
    ...     print(dataset[0])  # becomes pytorch tensor
    ('data1', tensor([0, 1, 2, 3]))
    >>> print(dataset[0])  # numpy array again
    ('data1', array([0, 1, 2, 3]))

    """

    def __init__(self,
                 data: list[Any],
                 targets: list[Any],
                 batch_size: int = 0,
                 shuffle: bool = True) -> None:
        """Initialise a Dataset for lambeq training.

        Parameters
        ----------
        data : list
            Data used for training.
        targets : list
            List of labels.
        batch_size : int, default: 0
            Batch size for batch generation, by default full dataset.
        shuffle : bool, default: True
            Enable data shuffling during training.

        Raises
        ------
        ValueError
            When 'data' and 'targets' do not match in size.

        """
        if len(data) != len(targets):
            raise ValueError('Lengths of `data` and `targets` differ.')
        self.data = data
        self.targets = targets
        self.batch_size = batch_size
        self.shuffle = shuffle

        # A batch size of 0 means "use the whole dataset as one batch".
        if self.batch_size == 0:
            self.batch_size = len(self.data)

        # ceil so that a trailing, smaller batch still counts as a batch.
        self.batches_per_epoch = ceil(len(self.data) / self.batch_size)

    def __getitem__(self, index: int | slice) -> tuple[Any, Any]:
        """Get a single item or a subset from the dataset.

        Targets are converted to the current numerical backend's array
        type at access time, so a backend switch takes effect
        immediately.
        """
        x = self.data[index]
        y = self.targets[index]
        return x, get_backend().array(y)

    def __len__(self) -> int:
        """Return the number of data points in the dataset."""
        return len(self.data)

    def __iter__(self) -> Iterator[tuple[list[Any], Any]]:
        """Iterate over data batches.

        Yields
        ------
        tuple of list and any
            An iterator that yields data batches (X_batch, y_batch).

        """
        new_data, new_targets = self.data, self.targets
        if self.shuffle:
            new_data, new_targets = self.shuffle_data(new_data, new_targets)

        # Look up the backend once per epoch, not per batch.
        backend = get_backend()
        for start_idx in range(0, len(self.data), self.batch_size):
            yield (new_data[start_idx: start_idx + self.batch_size],
                   backend.array(
                       new_targets[start_idx: start_idx + self.batch_size],
                       dtype=backend.float32))

    @staticmethod
    def shuffle_data(data: list[Any],
                     targets: list[Any]) -> tuple[list[Any], list[Any]]:
        """Shuffle a given dataset.

        Data and targets are shuffled jointly, so each data point keeps
        its label.

        Parameters
        ----------
        data : list
            List of data points.
        targets : list
            List of labels.

        Returns
        -------
        Tuple of list and list
            The shuffled dataset.

        """
        joint_list = list(zip(data, targets))
        random.shuffle(joint_list)
        data_tuple, targets_tuple = zip(*joint_list)
        return list(data_tuple), list(targets_tuple)