Quantum pipeline using JAX backend

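This notebook trains a quantum pipeline using an exact classical simulation of the circuits, evaluated with lambeq's `NumpyModel` with JAX just-in-time compilation enabled.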

[1]:
# Standard-library modules used to configure the runtime environment.
import os
import warnings

# Disable tokenizer parallelism through its environment variable
# (presumably read by the HuggingFace `tokenizers` package -- TODO confirm).
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Silence every warning for the rest of the session to keep cell output uncluttered.
warnings.filterwarnings(action="ignore")
[2]:
import numpy as np

# Training configuration, gathered in one place so it is easy to tune.
SEED = 0               # random seed value
EPOCHS = 120           # number of training epochs
BATCH_SIZE = 30        # examples per batch of the training dataset
LEARNING_RATE = 0.03   # NOTE(review): not referenced by the SPSA trainer cell of this notebook

Read in the data and create diagrams

[3]:
def read_data(filename):
    """Load a labelled sentence file.

    Each non-blank line is expected to start with a single-digit binary
    label (``0`` or ``1``) followed by the sentence text.

    Parameters
    ----------
    filename : str or path-like
        Path of the text file to read (decoded as UTF-8).

    Returns
    -------
    labels : numpy.ndarray
        One row ``[t, 1 - t]`` per sentence, where ``t`` is the label
        read from the first character of the line.
    sentences : list of str
        The sentence text of each line with surrounding whitespace removed.

    Raises
    ------
    ValueError
        If the first character of a non-blank line is not a digit.
    """
    labels, sentences = [], []
    # Explicit encoding so the result does not depend on the platform default.
    with open(filename, encoding='utf-8') as f:
        for line in f:
            # Blank lines (e.g. a trailing empty line) carry no example;
            # skip them instead of failing on int('\n').
            if not line.strip():
                continue
            t = int(line[0])
            labels.append([t, 1 - t])
            sentences.append(line[1:].strip())
    return np.array(labels), sentences


# Load the train, dev and test splits (in that order); each call yields a
# (labels, sentences) pair.
(train_labels, train_data), (dev_labels, dev_data), (test_labels, test_data) = map(
    read_data,
    ('datasets/mc_train_data.txt',
     'datasets/mc_dev_data.txt',
     'datasets/mc_test_data.txt'),
)

Create diagrams

[4]:
from lambeq import BobcatParser

# Bobcat parser instance; verbose='text' prints a progress message for each
# parsing stage.
parser = BobcatParser(verbose='text')

# Turn the sentences of every split into diagrams, processing the splits in
# the order train, dev, test.
raw_train_diagrams, raw_dev_diagrams, raw_test_diagrams = (
    parser.sentences2diagrams(split_sentences)
    for split_sentences in (train_data, dev_data, test_data)
)
Tagging sentences.
Parsing tagged sentences.
Turning parse trees to diagrams.
Tagging sentences.
Parsing tagged sentences.
Turning parse trees to diagrams.
Tagging sentences.
Parsing tagged sentences.
Turning parse trees to diagrams.

Remove the cups

[5]:
from lambeq import remove_cups

# Apply remove_cups to every raw diagram of each split.
train_diagrams = list(map(remove_cups, raw_train_diagrams))
dev_diagrams = list(map(remove_cups, raw_dev_diagrams))
test_diagrams = list(map(remove_cups, raw_test_diagrams))

# Draw the first training diagram as a visual check of the rewrite.
train_diagrams[0].draw()
../_images/examples_quantum_pipeline_jax_9_0.png

Create circuits

[6]:
from lambeq import AtomicType, IQPAnsatz

# IQP ansatz with a single layer and three single-qubit parameters.
# The mapping assigns 1 to both the noun and sentence types
# (presumably the number of qubits per wire -- TODO confirm against lambeq docs).
ansatz = IQPAnsatz(
    {AtomicType.NOUN: 1, AtomicType.SENTENCE: 1},
    n_layers=1,
    n_single_qubit_params=3,
)

# Convert the diagrams of each split into circuits with the same ansatz.
train_circuits = list(map(ansatz, train_diagrams))
dev_circuits = list(map(ansatz, dev_diagrams))
test_circuits = list(map(ansatz, test_diagrams))

# Draw the first training circuit as a visual check.
train_circuits[0].draw(figsize=(9, 12))
../_images/examples_quantum_pipeline_jax_11_0.png

Parameterise

[7]:
from lambeq import NumpyModel

# The model is built from the circuits of all three splits, not only the
# training circuits (presumably so that every symbol occurring in the dev and
# test circuits is known to the model -- TODO confirm).
all_circuits = [*train_circuits, *dev_circuits, *test_circuits]

# use_jit=True enables JIT compilation in the NumpyModel.
model = NumpyModel.from_diagrams(all_circuits, use_jit=True)

Define evaluation metric

[8]:
from lambeq import BinaryCrossEntropyLoss

# Use lambeq's built-in binary cross-entropy loss; it is handed to the trainer
# as its loss function. use_jax=True accompanies the JIT-enabled model
# (presumably it selects the JAX implementation of numpy -- TODO confirm).
bce = BinaryCrossEntropyLoss(use_jax=True)

def acc(y_hat, y):
    """Return the fraction of examples whose rounded prediction matches the label.

    Parameters
    ----------
    y_hat : array-like
        Predicted values, one row per example and one column per label entry.
    y : array-like
        Two-column target labels of the form ``[t, 1 - t]``.

    Returns
    -------
    Accuracy as a value between 0 and 1.
    """
    # Element-wise comparison of the rounded predictions with the targets.
    matches = np.round(y_hat) == y
    # Every example contributes two compared entries (one per label column),
    # so the match count is halved to avoid double-counting.
    return np.sum(matches) / len(y) / 2

Initialize trainer

[9]:
from lambeq import QuantumTrainer, SPSAOptimizer

# Trainer driving the optimisation of `model` with the SPSA optimiser.
# - loss_function: binary cross-entropy defined above.
# - optim_hyperparams: SPSA settings; 'A' scales with the number of epochs
#   (presumably 'a' is the initial learning rate, 'c' the perturbation size and
#   'A' a stability constant -- TODO confirm against the SPSAOptimizer docs).
# - evaluate_on_train=True: the 'acc' metric is also reported on the training set.
# - seed: taken from the SEED constant of the configuration cell instead of a
#   hard-coded literal, so the seed is controlled from a single place.
trainer = QuantumTrainer(
    model,
    loss_function=bce,
    epochs=EPOCHS,
    optimizer=SPSAOptimizer,
    optim_hyperparams={'a': 0.2, 'c': 0.06, 'A': 0.01 * EPOCHS},
    evaluate_functions={'acc': acc},
    evaluate_on_train=True,
    verbose='text',
    seed=SEED,
)
[10]:
from lambeq import Dataset

# Training data: circuits paired with their labels, served in batches of BATCH_SIZE.
train_dataset = Dataset(train_circuits, train_labels, batch_size=BATCH_SIZE)

# Validation data: shuffling is switched off and no batch size is given,
# so the library default applies.
val_dataset = Dataset(
    dev_circuits,
    dev_labels,
    shuffle=False,
)

Train

[11]:
# Train on the training dataset and evaluate on the validation dataset;
# log_interval=12 prints the metrics every 12 epochs (see the log below).
trainer.fit(train_dataset, val_dataset, log_interval=12)
Epoch 12:   train/loss: 0.4061   valid/loss: 0.6106   train/acc: 0.7000   valid/acc: 0.7000
Epoch 24:   train/loss: 0.3656   valid/loss: 0.3698   train/acc: 0.7857   valid/acc: 0.8000
Epoch 36:   train/loss: 0.1651   valid/loss: 0.4405   train/acc: 0.8429   valid/acc: 0.7333
Epoch 48:   train/loss: 0.4870   valid/loss: 0.5849   train/acc: 0.7571   valid/acc: 0.7333
Epoch 60:   train/loss: 0.1931   valid/loss: 0.5443   train/acc: 0.8000   valid/acc: 0.7000
Epoch 72:   train/loss: 0.2040   valid/loss: 0.5040   train/acc: 0.9000   valid/acc: 0.8000
Epoch 84:   train/loss: 0.3860   valid/loss: 0.7005   train/acc: 0.9000   valid/acc: 0.8000
Epoch 96:   train/loss: 0.2307   valid/loss: 0.4480   train/acc: 0.9000   valid/acc: 0.8000
Epoch 108:  train/loss: 0.3077   valid/loss: 0.4548   train/acc: 0.9143   valid/acc: 0.8667
Epoch 120:  train/loss: 0.1187   valid/loss: 0.4960   train/acc: 0.9143   valid/acc: 0.8000

Training completed!

Show results

[12]:
import matplotlib.pyplot as plt
import numpy as np

# 2x2 grid of plots: loss on the top row, accuracy on the bottom row;
# training set in the left column, development set in the right column.
fig, ((ax_tl, ax_tr), (ax_bl, ax_br)) = plt.subplots(
    2, 2, sharex=True, sharey='row', figsize=(10, 6))
ax_tl.set_title('Training set')
ax_tr.set_title('Development set')
# The x values are the epoch indices 1..trainer.epochs, so the axis is
# labelled in epochs rather than iterations.
ax_bl.set_xlabel('Epochs')
ax_br.set_xlabel('Epochs')
ax_bl.set_ylabel('Accuracy')
ax_tl.set_ylabel('Loss')

colours = iter(plt.rcParams['axes.prop_cycle'].by_key()['color'])
range_ = np.arange(1, trainer.epochs + 1)

# Plot each series on its panel, taking the next colour of the default cycle
# for every curve (order: train loss, train acc, dev loss, dev acc).
for axis, series in (
    (ax_tl, trainer.train_epoch_costs),
    (ax_bl, trainer.train_eval_results['acc']),
    (ax_tr, trainer.val_costs),
    (ax_br, trainer.val_eval_results['acc']),
):
    axis.plot(range_, series, color=next(colours))

# Final evaluation of the trained model on the held-out test set.
test_acc = acc(model(test_circuits), np.array(test_labels))
print('Test accuracy:', test_acc)
Test accuracy: 0.8666667
../_images/examples_quantum_pipeline_jax_22_1.png