Quantum pipeline using JAX backend

This notebook performs an exact classical simulation of the circuits, using lambeq's `NumpyModel` with JAX just-in-time compilation.

[1]:
import os
import warnings

# Suppress all warnings for the session so they do not clutter cell outputs.
warnings.filterwarnings("ignore")

# Set before any parsing library is imported, so the value is already in the
# environment when those libraries start up.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
[2]:
import numpy as np

# Tunable settings for the whole notebook, collected in one place.
SEED = 0              # random seed for reproducibility
EPOCHS = 120          # number of training epochs
BATCH_SIZE = 30       # examples per training batch
# NOTE(review): LEARNING_RATE is not referenced by any visible cell; the SPSA
# optimiser is configured through its own hyperparameter dictionary instead.
LEARNING_RATE = 0.03

Read in the data and create diagrams

[3]:
def read_data(filename):
    """Load labelled sentences from a text file.

    Every non-blank line must start with a single-character integer label,
    followed by the sentence text. A label ``t`` is expanded into the
    two-element row ``[t, 1 - t]``. Blank lines are skipped.

    Parameters
    ----------
    filename : str or path-like
        Path of the file to read; it is decoded as UTF-8.

    Returns
    -------
    tuple[numpy.ndarray, list[str]]
        The array of ``[t, 1 - t]`` rows and the whitespace-stripped
        sentences, both in file order.

    Raises
    ------
    ValueError
        If the first character of a non-blank line is not a digit.
    """
    labels, sentences = [], []
    with open(filename, encoding='utf-8') as f:
        for line in f:
            # Blank lines (e.g. a trailing empty line at the end of the file)
            # carry no label; skip them instead of failing inside int().
            if not line.strip():
                continue
            t = int(line[0])
            labels.append([t, 1 - t])
            sentences.append(line[1:].strip())
    return np.array(labels), sentences


# Read the train / dev / test splits (in that order) into label arrays and
# sentence lists.
(train_labels, train_data), (dev_labels, dev_data), (test_labels, test_data) = map(
    read_data,
    (
        'datasets/mc_train_data.txt',
        'datasets/mc_dev_data.txt',
        'datasets/mc_test_data.txt',
    ),
)

Create diagrams

[4]:
from lambeq import BobcatParser

# One parser instance is shared by all three splits.
parser = BobcatParser(verbose='text')

# Convert each split's sentences to diagrams, in train / dev / test order.
raw_train_diagrams, raw_dev_diagrams, raw_test_diagrams = (
    parser.sentences2diagrams(sentences)
    for sentences in (train_data, dev_data, test_data)
)
Tagging sentences.
Parsing tagged sentences.
Turning parse trees to diagrams.
Tagging sentences.
Parsing tagged sentences.
Turning parse trees to diagrams.
Tagging sentences.
Parsing tagged sentences.
Turning parse trees to diagrams.

Remove the cups

[5]:
from lambeq import RemoveCupsRewriter

# The rewriter object is callable and is applied to one diagram at a time.
remove_cups = RemoveCupsRewriter()

train_diagrams = list(map(remove_cups, raw_train_diagrams))
dev_diagrams = list(map(remove_cups, raw_dev_diagrams))
test_diagrams = list(map(remove_cups, raw_test_diagrams))

# Show the first rewritten training diagram as a visual check.
train_diagrams[0].draw()
../_images/examples_quantum-pipeline-jax_9_0.png

Create circuits

[6]:
from lambeq import AtomicType, IQPAnsatz

# Mapping passed to the ansatz: value 1 for both the noun and sentence types.
qubits_per_type = {AtomicType.NOUN: 1, AtomicType.SENTENCE: 1}
ansatz = IQPAnsatz(qubits_per_type, n_layers=1, n_single_qubit_params=3)

# Apply the ansatz to every diagram of every split.
train_circuits = list(map(ansatz, train_diagrams))
dev_circuits = list(map(ansatz, dev_diagrams))
test_circuits = list(map(ansatz, test_diagrams))

# Show the first training circuit as a visual check.
train_circuits[0].draw(figsize=(9, 9))
../_images/examples_quantum-pipeline-jax_11_0.png

Parameterise

[7]:
from lambeq import NumpyModel

# The model is built from the circuits of all three splits together
# (presumably so that symbols occurring only in dev/test circuits also get
# parameters -- confirm against the lambeq documentation).
all_circuits = [*train_circuits, *dev_circuits, *test_circuits]
model = NumpyModel.from_diagrams(all_circuits, use_jit=True)

Define loss function and evaluation metric

[8]:
from lambeq import BinaryCrossEntropyLoss

# Training loss: lambeq's built-in binary cross-entropy. use_jax=True is
# passed here; the model in this notebook is created with use_jit=True
# (presumably the flag selects the JAX numpy implementation -- confirm in the
# lambeq documentation).
bce = BinaryCrossEntropyLoss(use_jax=True)

def acc(y_hat, y):
    """Return the classification accuracy of rounded predictions.

    Parameters
    ----------
    y_hat : array-like
        Predicted values, one row per example, with the same layout as ``y``.
    y : array-like
        Target labels, one row per example.

    Returns
    -------
    float
        Number of element-wise matches between ``round(y_hat)`` and ``y``,
        divided by ``len(y)`` and then by 2.
    """
    # Each example contributes two label entries, so matches are counted twice
    # per example; the final division by 2 undoes that double-counting.
    return np.sum(np.round(y_hat) == y) / len(y) / 2

Initialize trainer

[9]:
from lambeq import QuantumTrainer, SPSAOptimizer

# Trainer that optimises `model` with SPSA for EPOCHS epochs, minimising the
# binary cross-entropy loss and reporting accuracy on both the training and
# validation data.
trainer = QuantumTrainer(
    model,
    loss_function=bce,
    epochs=EPOCHS,
    optimizer=SPSAOptimizer,
    # SPSA hyperparameters; 'A' is tied to the number of epochs.
    # NOTE(review): the exact meaning of 'a', 'c' and 'A' is defined by
    # lambeq's SPSAOptimizer -- see its documentation before tuning them.
    optim_hyperparams={'a': 0.2, 'c': 0.06, 'A': 0.01 * EPOCHS},
    evaluate_functions={'acc': acc},
    evaluate_on_train=True,
    verbose='text',
    # Use the notebook-level SEED constant rather than a literal, so that
    # changing SEED in the configuration cell actually affects training.
    seed=SEED,
)
[10]:
from lambeq import Dataset

# Training data, served in batches of BATCH_SIZE.
train_dataset = Dataset(train_circuits, train_labels, batch_size=BATCH_SIZE)

# Validation data; shuffling is switched off.
val_dataset = Dataset(
    dev_circuits,
    dev_labels,
    shuffle=False,
)

Train

[11]:
# Train on the training set while evaluating on the validation set; with
# log_interval=12 the text log shows one line every 12 epochs.
trainer.fit(train_dataset, val_dataset, log_interval=12)
Epoch 12:   train/loss: 0.6305   valid/loss: 0.7275   train/acc: 0.6143   valid/acc: 0.4667
Epoch 24:   train/loss: 0.8510   valid/loss: 0.6093   train/acc: 0.6571   valid/acc: 0.6667
Epoch 36:   train/loss: 0.6167   valid/loss: 0.7777   train/acc: 0.7000   valid/acc: 0.6333
Epoch 48:   train/loss: 0.2073   valid/loss: 0.7530   train/acc: 0.6857   valid/acc: 0.6333
Epoch 60:   train/loss: 0.2046   valid/loss: 0.6212   train/acc: 0.8143   valid/acc: 0.6000
Epoch 72:   train/loss: 0.1950   valid/loss: 0.7029   train/acc: 0.9429   valid/acc: 0.8000
Epoch 84:   train/loss: 0.3449   valid/loss: 0.5687   train/acc: 0.9571   valid/acc: 0.7667
Epoch 96:   train/loss: 0.2179   valid/loss: 0.6687   train/acc: 0.9429   valid/acc: 0.7667
Epoch 108:  train/loss: 0.2575   valid/loss: 0.6051   train/acc: 0.9429   valid/acc: 0.7667
Epoch 120:  train/loss: 0.1563   valid/loss: 0.5788   train/acc: 0.9714   valid/acc: 0.8000

Training completed!

Show results

[12]:
import matplotlib.pyplot as plt
import numpy as np

# 2x2 grid: loss on the top row, accuracy on the bottom row; training set in
# the left column, development set in the right column.
fig, ((ax_tl, ax_tr), (ax_bl, ax_br)) = plt.subplots(
    2, 2, sharex=True, sharey='row', figsize=(10, 6)
)

for axis, title in ((ax_tl, 'Training set'), (ax_tr, 'Development set')):
    axis.set_title(title)
for axis in (ax_bl, ax_br):
    axis.set_xlabel('Iterations')
ax_bl.set_ylabel('Accuracy')
ax_tl.set_ylabel('Loss')

# One colour per curve, drawn from the default colour cycle in plotting order.
colours = iter(plt.rcParams['axes.prop_cycle'].by_key()['color'])
range_ = np.arange(1, trainer.epochs + 1)
curves = (
    (ax_tl, trainer.train_epoch_costs),
    (ax_bl, trainer.train_eval_results['acc']),
    (ax_tr, trainer.val_costs),
    (ax_br, trainer.val_eval_results['acc']),
)
for axis, series in curves:
    axis.plot(range_, series, color=next(colours))

# Accuracy of the trained model on the held-out test circuits.
test_acc = acc(model(test_circuits), np.array(test_labels))
print('Test accuracy:', test_acc)
Test accuracy: 0.6666667
../_images/examples_quantum-pipeline-jax_22_1.png