Quantum pipeline using JAX backend

This notebook performs an exact (noiseless) classical simulation of the quantum pipeline. The circuits are evaluated with a NumpyModel created with use_jit=True, so the computation is just-in-time compiled with JAX.

[1]:
import warnings
warnings.filterwarnings("ignore")

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
[2]:
import numpy as np

BATCH_SIZE = 30
LEARNING_RATE = 3e-2
EPOCHS = 120
SEED = 0

Read in the data and create diagrams

[3]:
def read_data(filename):
    """Read a dataset file where each line starts with a 0/1 label,
    followed by the sentence; return one-hot labels and sentences."""
    labels, sentences = [], []
    with open(filename) as f:
        for line in f:
            t = int(line[0])                    # first character is the class label
            labels.append([t, 1 - t])           # one-hot encoding of the label
            sentences.append(line[1:].strip())  # rest of the line is the sentence
    return np.array(labels), sentences


train_labels, train_data = read_data('datasets/mc_train_data.txt')
dev_labels, dev_data = read_data('datasets/mc_dev_data.txt')
test_labels, test_data = read_data('datasets/mc_test_data.txt')
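
To make the file format concrete, here is a minimal sketch with an illustrative line (not taken from the actual dataset files): the first character is the binary class label and the rest of the line is the sentence.

example_line = "1 skillful person prepares dinner .\n"   # hypothetical line in the mc_*_data.txt format
t = int(example_line[0])              # -> 1
label = [t, 1 - t]                    # one-hot pair -> [1, 0]
sentence = example_line[1:].strip()   # -> 'skillful person prepares dinner .'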

Create diagrams

[4]:
from lambeq import BobcatParser

parser = BobcatParser(verbose='text')

raw_train_diagrams = parser.sentences2diagrams(train_data)
raw_dev_diagrams = parser.sentences2diagrams(dev_data)
raw_test_diagrams = parser.sentences2diagrams(test_data)
Tagging sentences.
Parsing tagged sentences.
Turning parse trees to diagrams.
Tagging sentences.
Parsing tagged sentences.
Turning parse trees to diagrams.
Tagging sentences.
Parsing tagged sentences.
Turning parse trees to diagrams.
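
The parser also has a single-sentence counterpart, sentence2diagram, which is convenient for quick experiments; a minimal sketch (the sentence is illustrative):

# Parse one sentence and draw its diagram directly.
diagram = parser.sentence2diagram('person prepares dinner .')
diagram.draw(figsize=(8, 4))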

Remove the cups

Removing the cups reduces the number of post-selections needed once the diagrams are converted to circuits, which makes the simulated evaluation cheaper.

[5]:
from lambeq import RemoveCupsRewriter

remove_cups = RemoveCupsRewriter()

train_diagrams = [remove_cups(diagram) for diagram in raw_train_diagrams]
dev_diagrams = [remove_cups(diagram) for diagram in raw_dev_diagrams]
test_diagrams = [remove_cups(diagram) for diagram in raw_test_diagrams]

train_diagrams[0].draw()
[Output image: examples_quantum-pipeline-jax_9_0.png — the first training diagram after removing the cups]
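
For comparison, the original diagram (with the cups still present) can be drawn in the same way:

# Draw the unrewritten diagram for the same sentence to see the cups
# that the rewriter removes.
raw_train_diagrams[0].draw()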

Create circuits

[6]:
from lambeq import AtomicType, IQPAnsatz

ansatz = IQPAnsatz({AtomicType.NOUN: 1, AtomicType.SENTENCE: 1},
                   n_layers=1, n_single_qubit_params=3)

train_circuits = [ansatz(diagram) for diagram in train_diagrams]
dev_circuits = [ansatz(diagram) for diagram in dev_diagrams]
test_circuits = [ansatz(diagram) for diagram in test_diagrams]

train_circuits[0].draw(figsize=(9, 9))
[Output image: examples_quantum-pipeline-jax_11_0.png — the parameterised circuit for the first training diagram]
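
It can be useful to check how many trainable parameters the ansatz has introduced; a small sketch, assuming the circuits expose their symbolic parameters through a free_symbols attribute (as the underlying DisCoPy circuits do):

# Count the symbolic parameters of the first training circuit
# (free_symbols is assumed from the underlying circuit representation).
print(len(train_circuits[0].free_symbols))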

Parameterise

[7]:
from lambeq import NumpyModel

all_circuits = train_circuits + dev_circuits + test_circuits

model = NumpyModel.from_diagrams(all_circuits, use_jit=True)
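
NumpyModel.from_diagrams scans every circuit and allocates one weight per distinct symbol, which is why the dev and test circuits are passed in alongside the training circuits; a small sketch, assuming the collected symbols are exposed as model.symbols:

# Total number of trainable parameters gathered from all circuits
# (the symbols attribute is assumed).
print(len(model.symbols))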

Define evaluation metric

[8]:
from lambeq import BinaryCrossEntropyLoss

# Use the built-in binary cross-entropy loss from lambeq
bce = BinaryCrossEntropyLoss(use_jax=True)

acc = lambda y_hat, y: np.sum(np.round(y_hat) == y) / len(y) / 2  # labels are 2-d one-hot vectors, so each correct prediction is counted twice
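
As a quick sanity check, the loss can be called directly on small arrays; this sketch assumes the loss takes predictions first and targets second:

# A confident correct prediction should score a much lower loss than a
# confident wrong one (argument order (y_pred, y_true) assumed).
y_true = np.array([[1.0, 0.0]])
print(bce(np.array([[0.9, 0.1]]), y_true))   # small loss
print(bce(np.array([[0.1, 0.9]]), y_true))   # large loss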

Initialize trainer

[9]:
from lambeq import QuantumTrainer, SPSAOptimizer

trainer = QuantumTrainer(
    model,
    loss_function=bce,
    epochs=EPOCHS,
    optimizer=SPSAOptimizer,
    # SPSA hyperparameters: 'a' scales the step size, 'c' sets the size of the
    # random perturbations, and 'A' is the stability constant.
    optim_hyperparams={'a': 0.2, 'c': 0.06, 'A': 0.01 * EPOCHS},
    evaluate_functions={'acc': acc},
    evaluate_on_train=True,
    verbose='text',
    seed=SEED
)
[10]:
from lambeq import Dataset

train_dataset = Dataset(train_circuits,
                        train_labels,
                        batch_size=BATCH_SIZE)

val_dataset = Dataset(dev_circuits, dev_labels, shuffle=False)

Train

[11]:
trainer.fit(train_dataset, val_dataset, log_interval=12)
Epoch 12:   train/loss: 0.6881   valid/loss: 0.6194   train/time: 7.73s   valid/time: 3.04s   train/acc: 0.6429   valid/acc: 0.7000
Epoch 24:   train/loss: 0.5321   valid/loss: 0.5429   train/time: 0.23s   valid/time: 0.05s   train/acc: 0.7143   valid/acc: 0.7333
Epoch 36:   train/loss: 0.4615   valid/loss: 0.4834   train/time: 0.23s   valid/time: 0.05s   train/acc: 0.7714   valid/acc: 0.8000
Epoch 48:   train/loss: 0.2858   valid/loss: 0.4100   train/time: 0.24s   valid/time: 0.05s   train/acc: 0.8429   valid/acc: 0.7667
Epoch 60:   train/loss: 0.1604   valid/loss: 0.3585   train/time: 0.23s   valid/time: 0.05s   train/acc: 0.9143   valid/acc: 0.8333
Epoch 72:   train/loss: 0.2836   valid/loss: 0.3231   train/time: 0.23s   valid/time: 0.05s   train/acc: 0.9429   valid/acc: 0.8333
Epoch 84:   train/loss: 0.3280   valid/loss: 0.3091   train/time: 0.23s   valid/time: 0.05s   train/acc: 0.9143   valid/acc: 0.8333
Epoch 96:   train/loss: 0.2500   valid/loss: 0.2911   train/time: 0.23s   valid/time: 0.05s   train/acc: 0.9429   valid/acc: 0.8333
Epoch 108:  train/loss: 0.1780   valid/loss: 0.3062   train/time: 0.24s   valid/time: 0.05s   train/acc: 0.9286   valid/acc: 0.8333
Epoch 120:  train/loss: 0.0662   valid/loss: 0.2910   train/time: 0.25s   valid/time: 0.05s   train/acc: 0.9429   valid/acc: 0.8333

Training completed!
train/time: 9.83s   train/time_per_epoch: 0.08s   train/time_per_step: 0.03s   valid/time: 3.49s   valid/time_per_eval: 0.03s

Show results

[12]:
import matplotlib.pyplot as plt
import numpy as np

fig, ((ax_tl, ax_tr), (ax_bl, ax_br)) = plt.subplots(2, 2, sharex=True, sharey='row', figsize=(10, 6))
ax_tl.set_title('Training set')
ax_tr.set_title('Development set')
ax_bl.set_xlabel('Iterations')
ax_br.set_xlabel('Iterations')
ax_bl.set_ylabel('Accuracy')
ax_tl.set_ylabel('Loss')

colours = iter(plt.rcParams['axes.prop_cycle'].by_key()['color'])
range_ = np.arange(1, trainer.epochs + 1)
ax_tl.plot(range_, trainer.train_epoch_costs, color=next(colours))
ax_bl.plot(range_, trainer.train_eval_results['acc'], color=next(colours))
ax_tr.plot(range_, trainer.val_costs, color=next(colours))
ax_br.plot(range_, trainer.val_eval_results['acc'], color=next(colours))

test_acc = acc(model(test_circuits), np.array(test_labels))
print('Test accuracy:', test_acc)
Test accuracy: 0.96666664
[Output image: examples_quantum-pipeline-jax_22_1.png — loss and accuracy curves for the training and development sets]
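
Finally, the trained parameters can be persisted for later use; a minimal sketch, assuming the model exposes a save method and a from_checkpoint constructor (the file name is illustrative):

# Persist the trained weights and reload them later
# (method names assumed from the lambeq model API; the path is illustrative).
model.save('mc_numpy_model.lt')
reloaded = NumpyModel.from_checkpoint('mc_numpy_model.lt')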