Training: Quantum case

In this tutorial we will train a lambeq model to solve the relative pronoun classification task presented in [Lea2021]. The goal is to predict whether a noun phrase contains a subject-based or an object-based relative clause. The entries of this dataset are extracted from the RelPron dataset [Rea2016].

We will use an IQPAnsatz to convert string diagrams into quantum circuits. The pipeline uses tket as a backend.

If you have already gone through the classical training tutorial, you will see that there are only minor differences for the quantum case.


Preparation

We start by suppressing warnings and configuring the tokenizer, then import NumPy and specify some training hyperparameters.

[1]:
import os
import warnings

warnings.filterwarnings('ignore')
os.environ['TOKENIZERS_PARALLELISM'] = 'true'

Note

We disable warnings to suppress messages from the tqdm package used in Jupyter notebooks. We also set TOKENIZERS_PARALLELISM to specify whether the tokenizer used by the BobcatParser may run in parallel. Neither setting impairs the performance of the code.

[2]:
import numpy as np

BATCH_SIZE = 10
EPOCHS = 100
SEED = 2

Input data

Let’s read the data and print some example sentences.

[3]:
def read_data(filename):
    labels, sentences = [], []
    with open(filename) as f:
        for line in f:
            t = int(line[0])
            labels.append([t, 1-t])
            sentences.append(line[1:].strip())
    return labels, sentences


train_labels, train_data = read_data('../examples/datasets/rp_train_data.txt')
val_labels, val_data = read_data('../examples/datasets/rp_test_data.txt')
[4]:
train_data[:5]
[4]:
['organization that church establish .',
 'organization that team join .',
 'organization that company sell .',
 'organization that soldier serve .',
 'organization that sailor join .']

Targets are represented as 2-element one-hot vectors: a label t read from the file becomes [t, 1 - t].

[5]:
train_labels[:5]
[5]:
[[1, 0], [1, 0], [1, 0], [1, 0], [1, 0]]
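The sketch below mirrors read_data on individual in-memory lines, assuming (as read_data does) that the label is the first character of each line. The helper names parse_line and decode are hypothetical, and the second input line is made up purely for illustration.

```python
def parse_line(line):
    """Hypothetical helper mirroring read_data for a single line."""
    t = int(line[0])                  # first character is the class label (0 or 1)
    return [t, 1 - t], line[1:].strip()


def decode(label):
    """Recover the integer class from a one-hot [t, 1 - t] vector."""
    return label[0]


# hypothetical input lines; only the first sentence appears in the output above
samples = ['1 organization that church establish .\n',
           '0 example noun phrase .\n']
parsed = [parse_line(s) for s in samples]
for label, sentence in parsed:
    print(label, decode(label), sentence)
```

Because both entries of the vector are derived from the same digit, decoding only needs the first position.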

Creating and parameterising diagrams

The first step is to convert sentences into string diagrams.

Note

Since this dataset consists only of noun phrases, we reduce potential parsing errors by restricting the parser to parse trees whose root category is N (noun) or NP (noun phrase).

[6]:
from lambeq import BobcatParser

parser = BobcatParser(root_cats=('NP', 'N'), verbose='text')

raw_train_diagrams = parser.sentences2diagrams(train_data, suppress_exceptions=True)
raw_val_diagrams = parser.sentences2diagrams(val_data, suppress_exceptions=True)
Tagging sentences.
Parsing tagged sentences.
Turning parse trees to diagrams.
Tagging sentences.
Parsing tagged sentences.
Turning parse trees to diagrams.

Filter and simplify diagrams

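Because sentences2diagrams() was called with suppress_exceptions=True, any sentence that fails to parse yields None instead of raising an exception. We drop these entries, together with their labels, and simplify the remaining diagrams by calling normal_form().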

[7]:
train_diagrams = [
    diagram.normal_form()
    for diagram in raw_train_diagrams if diagram is not None
]
val_diagrams = [
    diagram.normal_form()
    for diagram in raw_val_diagrams if diagram is not None
]

train_labels = [
    label for (diagram, label)
    in zip(raw_train_diagrams, train_labels)
    if diagram is not None
]
val_labels = [
    label for (diagram, label)
    in zip(raw_val_diagrams, val_labels)
    if diagram is not None
]
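The two filtering passes above stay aligned because both iterate over the same raw diagram list. A single-pass alternative, sketched below, makes that alignment explicit by filtering diagram/label pairs together. The helper filter_parsed is hypothetical, placeholder strings stand in for diagrams, and it does not call normal_form().

```python
def filter_parsed(raw_diagrams, labels):
    """Hypothetical helper: drop failed parses (None) together with their labels."""
    if len(raw_diagrams) != len(labels):
        # zip() would silently truncate, so fail loudly on a length mismatch
        raise ValueError('diagrams and labels must have the same length')
    kept = [(d, lab) for d, lab in zip(raw_diagrams, labels) if d is not None]
    diagrams = [d for d, _ in kept]
    kept_labels = [lab for _, lab in kept]
    return diagrams, kept_labels


# placeholder strings stand in for parser output; None marks a failed parse
raw = ['diagram_a', None, 'diagram_c']
labels = [[1, 0], [0, 1], [0, 1]]
diagrams, kept_labels = filter_parsed(raw, labels)
print(diagrams, kept_labels)
```

With real parser output, normal_form() would then be applied to each kept diagram.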

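Let’s look at the diagram of the first training example, "organization that church establish". This is an object-based relative clause: the head noun "organization" is the object of the verb "establish", while the noun that follows the relative pronoun ("church") is its subject.

[8]:
train_diagrams[0].draw(figsize=(9, 5), fontsize=12)
../_images/tutorials_trainer-quantum_19_0.png

For comparison, we also draw the last training example. In a subject-based relative clause the roles are reversed: the head noun is the subject of the verb, and the noun that follows the verb is its object.

[9]:
train_diagrams[-1].draw(figsize=(9, 5), fontsize=12)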
../_images/tutorials_trainer-quantum_21_0.png

Create circuits

To run the experiments on a quantum computer, we need to apply a quantum ansatz to the string diagrams. For this experiment, we use an IQPAnsatz in which each noun wire (n) is mapped to one qubit and each sentence wire (s) to zero qubits, i.e. it is discarded, since the dataset contains only noun phrases. We use a single IQP layer (n_layers=1) and three parameters for single-qubit rotations (n_single_qubit_params=3).
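To get a feel for circuit sizes, the sketch below counts the qubits a word needs under the mapping {n: 1, s: 0} and estimates its parameter count. Assumptions: pregroup types are written as lists of atom strings such as 'n.r', with adjoints using the same number of qubits as their base atom; and the parameter rule (n_single_qubit_params rotation angles for a one-qubit word, n_layers * (k - 1) controlled-rotation angles for a k-qubit word) reflects our understanding of IQPAnsatz, not a documented guarantee.

```python
QUBIT_MAP = {'n': 1, 's': 0}   # mirrors {AtomicType.NOUN: 1, AtomicType.SENTENCE: 0}


def n_qubits(word_type):
    """Count qubits for a type given as atom strings, e.g. ['n.r', 's', 'n.l']."""
    # adjoints (.l, .r, .l.l, ...) take as many qubits as their base atom
    return sum(QUBIT_MAP[atom.split('.')[0]] for atom in word_type)


def n_params(word_type, n_layers=1, n_single_qubit_params=3):
    """Assumed IQP parameter count for one word (see the assumptions above)."""
    k = n_qubits(word_type)
    if k == 0:
        return 0                      # nothing left to parameterise
    if k == 1:
        return n_single_qubit_params  # single-qubit rotation angles
    return n_layers * (k - 1)         # one angle per adjacent qubit pair per layer


noun = ['n']
transitive_verb = ['n.r', 's', 'n.l']   # standard pregroup type of a transitive verb
print(n_qubits(noun), n_params(noun))
print(n_qubits(transitive_verb), n_params(transitive_verb))
```

Because the sentence wire maps to zero qubits, a transitive verb needs only two qubits here.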

[10]:
from lambeq import AtomicType, IQPAnsatz, RemoveCupsRewriter

ansatz = IQPAnsatz({AtomicType.NOUN: 1, AtomicType.SENTENCE: 0},
                   n_layers=1, n_single_qubit_params=3)
remove_cups = RemoveCupsRewriter()

train_circuits = [ansatz(remove_cups(diagram)) for diagram in train_diagrams]
val_circuits = [ansatz(remove_cups(diagram)) for diagram in val_diagrams]

train_circuits[0].draw(figsize=(9, 10))
../_images/tutorials_trainer-quantum_24_0.png

Note that we remove the cups before parameterising the diagrams. This reduces the number of post-selections, which makes the model computationally more efficient: on a shot-based backend, every post-selected qubit discards the shots that do not produce the required outcome. The effect of removing cups from a string diagram is demonstrated below:

[11]:
from lambeq.backend import draw_equation

original_diagram = train_diagrams[0]
removed_cups_diagram = remove_cups(original_diagram)

draw_equation(original_diagram, removed_cups_diagram, symbol='-->', figsize=(15, 6), asymmetry=0.3, fontsize=12)
../_images/tutorials_trainer-quantum_26_0.png
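A rough calculation shows why fewer post-selections help. Under a simplifying assumption that real circuits need not satisfy (each post-selected qubit independently yields the required outcome with probability 1/2), the expected number of usable shots halves with every post-selected qubit:

```python
def expected_kept_shots(shots, n_postselected, p_success=0.5):
    """Expected shots surviving post-selection, assuming independent qubits."""
    return shots * p_success ** n_postselected


for k in range(5):
    print(k, expected_kept_shots(8192, k))
```

Fewer surviving shots means noisier probability estimates for the same runtime, so each removed post-selection is worth having.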

Training

Instantiate the model

We will use a TketModel, which we initialise by passing all circuits (training and validation) to the class method TketModel.from_diagrams(), so that the model holds a parameter for every symbol that occurs in either set. The TketModel also needs a backend configuration dictionary, passed as the backend_config keyword argument, which must contain entries for backend, compilation and shots. The backend is provided by a pytket extension, here pytket.extensions.qiskit. In this example, we use Qiskit’s AerBackend with its default compilation pass at optimisation level 2 and 8192 shots.

[12]:
from pytket.extensions.qiskit import AerBackend
from lambeq import TketModel

all_circuits = train_circuits + val_circuits

backend = AerBackend()
backend_config = {
    'backend': backend,
    'compilation': backend.default_compilation_pass(2),
    'shots': 8192
}

model = TketModel.from_diagrams(all_circuits, backend_config=backend_config)
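The shots entry sets how many times each circuit is sampled, so the model's output probabilities are estimates. As a back-of-the-envelope check, treating each shot as an independent Bernoulli trial (an assumption that ignores post-selection losses and hardware noise), the standard error of an estimated probability p from N shots is sqrt(p(1 - p) / N):

```python
import math


def shot_standard_error(p, shots):
    """Standard error of a probability estimated from independent shots."""
    return math.sqrt(p * (1 - p) / shots)


for shots in (1024, 8192):
    print(shots, round(shot_standard_error(0.5, shots), 4))
```

Eight times as many shots reduce the error only by a factor of sqrt(8), roughly 2.8, which is the usual trade-off between runtime and estimate quality.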

Note

The model can also be instantiated by calling TketModel.from_checkpoint(), in case a pre-trained checkpoint is available.

Define loss and evaluation metric

We use standard binary cross-entropy as the loss. Optionally, we can provide a dictionary of callable evaluation metrics with the signature metric(y_hat, y).

[13]:
from lambeq import BinaryCrossEntropyLoss

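# Use the built-in binary cross-entropy loss from lambeq
bce = BinaryCrossEntropyLoss()

# Divide by 2: each label is a 2-element one-hot vector, so a correct
# prediction matches the target in both positions and is counted twice.
acc = lambda y_hat, y: np.sum(np.round(y_hat) == y) / len(y) / 2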
eval_metrics = {"acc": acc}
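To see the halving at work, the small NumPy example below applies the acc definition from the cell above to made-up predictions (not model outputs). Each row is rounded and compared element-wise with its one-hot label.

```python
import numpy as np

acc = lambda y_hat, y: np.sum(np.round(y_hat) == y) / len(y) / 2

y_hat = np.array([[0.8, 0.2],    # rounds to [1, 0]: correct, 2 matches
                  [0.3, 0.7],    # rounds to [0, 1]: wrong, 0 matches
                  [0.1, 0.9]])   # rounds to [0, 1]: correct, 2 matches
y = [[1, 0], [1, 0], [0, 1]]
print(acc(y_hat, y))             # 4 matches / 3 examples / 2
```

One caveat of this metric: np.round rounds half to even, so an undecided prediction [0.5, 0.5] becomes [0, 0], matches one position of either label, and earns half credit.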

Initialise trainer

In lambeq, quantum pipelines are based on the QuantumTrainer class. We will also use the standard lambeq SPSA (simultaneous perturbation stochastic approximation) optimizer, implemented in the SPSAOptimizer class. It needs three hyperparameters:

  • a: the initial learning rate (decays over time),

  • c: the initial parameter shift scaling factor (decays over time),

  • A: a stability constant; a recommended choice is approximately 0.01 * the number of training steps.
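The roles of a, c and A are easiest to see in a single SPSA update. Below is a minimal NumPy sketch of the textbook SPSA step (Spall's method) on a toy quadratic loss. It is not lambeq's SPSAOptimizer code: the gain formulas and the decay exponents alpha = 0.602 and gamma = 0.101 (Spall's commonly recommended values) are assumptions here. The hyperparameters match the trainer below, with A = 0.001 * EPOCHS = 0.1.

```python
import numpy as np


def spsa_step(loss, theta, k, a, c, A, rng, alpha=0.602, gamma=0.101):
    """One textbook SPSA update at iteration k >= 1 (assumed gain schedule)."""
    a_k = a / (k + A) ** alpha          # learning rate decays with k
    c_k = c / k ** gamma                # perturbation size decays with k
    delta = rng.choice([-1.0, 1.0], size=theta.shape)  # random +/-1 perturbation
    # two loss evaluations estimate the gradient for all parameters at once
    diff = loss(theta + c_k * delta) - loss(theta - c_k * delta)
    grad = diff / (2 * c_k) * delta     # 1 / delta_i equals delta_i for +/-1 entries
    return theta - a_k * grad


loss = lambda th: float(np.sum((th - 1.0) ** 2))   # toy loss with minimum at 1
rng = np.random.default_rng(0)
theta = np.zeros(3)
for k in range(1, 201):
    theta = spsa_step(loss, theta, k, a=0.05, c=0.06, A=0.1, rng=rng)
print(np.round(theta, 3), round(loss(theta), 4))
```

The appeal for quantum models is that each step costs two circuit evaluations regardless of the number of parameters, unlike finite differences, which need two per parameter.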

[14]:
from lambeq import QuantumTrainer, SPSAOptimizer

trainer = QuantumTrainer(
    model,
    loss_function=bce,
    epochs=EPOCHS,
    optimizer=SPSAOptimizer,
    optim_hyperparams={'a': 0.05, 'c': 0.06, 'A': 0.001 * EPOCHS},
    evaluate_functions=eval_metrics,
    evaluate_on_train=True,
    verbose='text',
    log_dir='RelPron/logs',
    seed=0
)

Create datasets

To facilitate data shuffling and batching, lambeq provides a native Dataset class. Shuffling is enabled by default, and if not specified, the batch size is set to the length of the dataset.

[15]:
from lambeq import Dataset

train_dataset = Dataset(train_circuits, train_labels, batch_size=BATCH_SIZE)

val_dataset = Dataset(val_circuits, val_labels, shuffle=False)
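Conceptually, shuffled mini-batching works like the pure-Python sketch below. The generator iterate_batches is hypothetical and is not lambeq's Dataset implementation. It shuffles indices (when shuffling is on) so that each input stays paired with its label, and it falls back to a single full batch when no batch size is given, mirroring the defaults described above.

```python
import random


def iterate_batches(data, targets, batch_size=0, shuffle=True, seed=None):
    """Yield (inputs, labels) batches; batch_size=0 means one full batch."""
    if len(data) != len(targets):
        raise ValueError('data and targets must have the same length')
    size = batch_size or len(data)
    indices = list(range(len(data)))
    if shuffle:
        random.Random(seed).shuffle(indices)   # shuffle indices, not items
    for start in range(0, len(indices), size):
        chunk = indices[start:start + size]
        yield [data[i] for i in chunk], [targets[i] for i in chunk]


data = [f'circuit_{i}' for i in range(25)]        # placeholders for circuits
labels = [[i % 2, 1 - i % 2] for i in range(25)]
sizes = [len(x) for x, _ in iterate_batches(data, labels, batch_size=10, seed=0)]
print(sizes)
```

Shuffling indices rather than the two lists separately is what keeps inputs and labels aligned; the last batch is simply smaller when the dataset size is not a multiple of the batch size.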

Train

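We can now pass the datasets to the fit() method of the trainer to start training. With early_stopping_interval=10, training stops once the validation loss has not improved for 10 consecutive epochs, and the best model is saved to the log directory.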

[16]:
trainer.fit(train_dataset, val_dataset, early_stopping_interval=10)
Epoch 1:    train/loss: 1.6565   valid/loss: 2.5863   train/acc: 0.6000   valid/acc: 0.6129
Epoch 2:    train/loss: 2.7489   valid/loss: 2.7984   train/acc: 0.6643   valid/acc: 0.3871
Epoch 3:    train/loss: 0.2307   valid/loss: 1.0433   train/acc: 0.7214   valid/acc: 0.7903
Epoch 4:    train/loss: 0.4410   valid/loss: 2.6108   train/acc: 0.7857   valid/acc: 0.6774
Epoch 5:    train/loss: 1.4217   valid/loss: 2.0467   train/acc: 0.7143   valid/acc: 0.5968
Epoch 6:    train/loss: 2.7733   valid/loss: 2.1691   train/acc: 0.6214   valid/acc: 0.4839
Epoch 7:    train/loss: 0.7625   valid/loss: 1.2315   train/acc: 0.6571   valid/acc: 0.7742
Epoch 8:    train/loss: 2.8028   valid/loss: 3.1985   train/acc: 0.5286   valid/acc: 0.6129
Epoch 9:    train/loss: 0.7338   valid/loss: 1.4462   train/acc: 0.4857   valid/acc: 0.5161
Epoch 10:   train/loss: 1.1405   valid/loss: 3.5565   train/acc: 0.3429   valid/acc: 0.3387
Epoch 11:   train/loss: 2.2153   valid/loss: 2.9543   train/acc: 0.4143   valid/acc: 0.2258
Epoch 12:   train/loss: 1.9657   valid/loss: 2.2286   train/acc: 0.3714   valid/acc: 0.3548
Epoch 13:   train/loss: 0.7303   valid/loss: 1.3492   train/acc: 0.5143   valid/acc: 0.5806
Early stopping!
Best model saved to RelPron/logs/best_model.lt

Training completed!
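The stopping point in the log can be checked by hand. The sketch below applies a patience-style rule (stop once the validation loss has not improved for early_stopping_interval consecutive epochs) to the valid/loss values printed above. The function early_stop_epoch is hypothetical and reflects our reading of the behaviour, not lambeq's code.

```python
def early_stop_epoch(val_losses, patience):
    """Return (best_epoch, stop_epoch), 1-indexed, for a patience-based rule."""
    best_loss, best_epoch = float('inf'), 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss:
            best_loss, best_epoch = loss, epoch
        elif epoch - best_epoch >= patience:
            return best_epoch, epoch
    return best_epoch, len(val_losses)   # ran out of epochs without stopping


# valid/loss values copied from the training log above
val_losses = [2.5863, 2.7984, 1.0433, 2.6108, 2.0467, 2.1691, 1.2315,
              3.1985, 1.4462, 3.5565, 2.9543, 2.2286, 1.3492]
print(early_stop_epoch(val_losses, patience=10))
```

Under this rule the best validation loss occurs at epoch 3 and training stops at epoch 13, which matches the log.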

Results

Finally, we visualise the results and evaluate the best model on the validation set (read from rp_test_data.txt).

[17]:
import matplotlib.pyplot as plt

fig, ((ax_tl, ax_tr), (ax_bl, ax_br)) = plt.subplots(2, 2, sharex=True, sharey='row', figsize=(10, 6))
ax_tl.set_title('Training set')
ax_tr.set_title('Development set')
ax_bl.set_xlabel('Iterations')
ax_br.set_xlabel('Iterations')
ax_bl.set_ylabel('Accuracy')
ax_tl.set_ylabel('Loss')

colours = iter(plt.rcParams['axes.prop_cycle'].by_key()['color'])
range_ = np.arange(1, len(trainer.train_epoch_costs)+1)
ax_tl.plot(range_, trainer.train_epoch_costs, color=next(colours))
ax_bl.plot(range_, trainer.train_eval_results['acc'], color=next(colours))
ax_tr.plot(range_, trainer.val_costs, color=next(colours))
ax_br.plot(range_, trainer.val_eval_results['acc'], color=next(colours))

# mark the best model (lowest validation loss) with a circle
best_epoch = np.argmin(trainer.val_costs)
ax_tl.plot(best_epoch + 1, trainer.train_epoch_costs[best_epoch], 'o', color='black', fillstyle='none')
ax_tr.plot(best_epoch + 1, trainer.val_costs[best_epoch], 'o', color='black', fillstyle='none')
ax_bl.plot(best_epoch + 1, trainer.train_eval_results['acc'][best_epoch], 'o', color='black', fillstyle='none')
ax_br.plot(best_epoch + 1, trainer.val_eval_results['acc'][best_epoch], 'o', color='black', fillstyle='none')

ax_tr.text(best_epoch + 1.4, trainer.val_costs[best_epoch], 'early stopping', va='center')

# load the best checkpoint and report its accuracy on the validation set
model.load(trainer.log_dir + '/best_model.lt')
test_acc = acc(model(val_circuits), val_labels)
print('Validation accuracy:', test_acc.item())
Validation accuracy: 0.8225806451612904
../_images/tutorials_trainer-quantum_44_1.png
