Quantum pipeline using JAX backend
This notebook trains the model by exact classical simulation of the quantum circuits.
[1]:
import warnings
# Suppress every Python warning for the rest of the session
# (presumably to keep the cell outputs free of library warnings).
warnings.filterwarnings("ignore")
import os
# NOTE(review): presumably stops the Hugging Face tokenizers library (used
# indirectly by the parser) from spawning parallel workers — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
[2]:
import numpy as np
# Notebook-level configuration used by the cells below.
BATCH_SIZE = 30  # batch size of the training Dataset
LEARNING_RATE = 3e-2  # NOTE(review): not referenced by any cell visible in this notebook
EPOCHS = 120  # number of training epochs; also scales the optimizer's 'A' hyperparameter
SEED = 0  # random seed value (the trainer is seeded with this same value)
Read in the data and create diagrams
[3]:
def read_data(filename):
    """Load a labelled sentence file.

    Every non-blank line must start with a single-digit integer label,
    followed by the sentence text. Blank lines (for example a trailing
    newline at the end of the file) are skipped instead of raising.

    Args:
        filename: Path of the text file to read.

    Returns:
        A ``(labels, sentences)`` tuple. ``labels`` is a NumPy array holding
        one ``[t, 1 - t]`` row per sentence, where ``t`` is the line's label;
        ``sentences`` is a list of the sentence strings with surrounding
        whitespace removed.

    Raises:
        ValueError: If a non-blank line does not start with a digit.
    """
    labels, sentences = [], []
    with open(filename) as f:
        for line in f:
            # A blank line has no label character to parse; ignore it.
            if not line.strip():
                continue
            t = int(line[0])
            # Two-column encoding: the label and its complement.
            labels.append([t, 1 - t])
            sentences.append(line[1:].strip())
    return np.array(labels), sentences
# Load the train / dev / test splits; each call returns (label array, sentence list).
train_labels, train_data = read_data('datasets/mc_train_data.txt')
dev_labels, dev_data = read_data('datasets/mc_dev_data.txt')
test_labels, test_data = read_data('datasets/mc_test_data.txt')
Create diagrams
[4]:
from lambeq import BobcatParser
# verbose='text' presumably produces the textual progress messages shown in
# this cell's output ("Tagging sentences.", ...).
parser = BobcatParser(verbose='text')
# Parse the sentences of each split into diagrams, one diagram per sentence.
raw_train_diagrams = parser.sentences2diagrams(train_data)
raw_dev_diagrams = parser.sentences2diagrams(dev_data)
raw_test_diagrams = parser.sentences2diagrams(test_data)
Tagging sentences.
Parsing tagged sentences.
Turning parse trees to diagrams.
Tagging sentences.
Parsing tagged sentences.
Turning parse trees to diagrams.
Tagging sentences.
Parsing tagged sentences.
Turning parse trees to diagrams.
Remove the cups
[5]:
from lambeq import RemoveCupsRewriter
# One rewriter instance is applied to every raw diagram of every split.
remove_cups = RemoveCupsRewriter()
train_diagrams = [remove_cups(diagram) for diagram in raw_train_diagrams]
dev_diagrams = [remove_cups(diagram) for diagram in raw_dev_diagrams]
test_diagrams = [remove_cups(diagram) for diagram in raw_test_diagrams]
# Draw the first rewritten training diagram as a visual sanity check.
train_diagrams[0].draw()
Create circuits
[6]:
from lambeq import AtomicType, IQPAnsatz
# NOTE(review): the dict presumably assigns one qubit to each of the noun and
# sentence atomic types — confirm against the IQPAnsatz documentation.
ansatz = IQPAnsatz({AtomicType.NOUN: 1, AtomicType.SENTENCE: 1},
n_layers=1, n_single_qubit_params=3)
# Apply the ansatz to every cup-free diagram to obtain one circuit per sentence.
train_circuits = [ansatz(diagram) for diagram in train_diagrams]
dev_circuits = [ansatz(diagram) for diagram in dev_diagrams]
test_circuits = [ansatz(diagram) for diagram in test_diagrams]
# Draw the first training circuit.
train_circuits[0].draw(figsize=(9, 9))
Parameterise
[7]:
from lambeq import NumpyModel
# The model is built from the circuits of all three splits (train + dev + test),
# presumably so that parameters appearing only in dev/test circuits are also
# registered with the model.
all_circuits = train_circuits + dev_circuits + test_circuits
# use_jit=True: NOTE(review) presumably enables JAX just-in-time compilation — confirm.
model = NumpyModel.from_diagrams(all_circuits, use_jit=True)
Define evaluation metric
[8]:
from lambeq import BinaryCrossEntropyLoss
# Use lambeq's built-in binary cross-entropy loss; it is passed to the trainer
# as the loss function. use_jax=True presumably selects the JAX implementation,
# matching the JAX-backed model — confirm.
bce = BinaryCrossEntropyLoss(use_jax=True)
def acc(y_hat, y):
    """Return the fraction of samples whose rounded prediction matches the label.

    Each label row carries two complementary entries, so every sample takes
    part in two element-wise comparisons; the final division by 2 compensates
    for that double-counting.

    Args:
        y_hat: Predicted values, one row per sample.
        y: Target labels with the same layout as ``y_hat``.
    """
    matching_entries = np.sum(np.round(y_hat) == y)
    return matching_entries / len(y) / 2
Initialize trainer
[9]:
from lambeq import QuantumTrainer, SPSAOptimizer

# Configure training of `model` with the SPSA optimizer.
# - The loss is the binary cross-entropy object defined above; `acc` is
#   reported as an extra metric, also on the training set
#   (evaluate_on_train=True).
# - 'A' is tied to the number of epochs (1% of EPOCHS).
# - NOTE(review): see lambeq's SPSAOptimizer documentation for the exact
#   meaning of the 'a', 'c' and 'A' hyperparameters.
trainer = QuantumTrainer(
    model,
    loss_function=bce,
    epochs=EPOCHS,
    optimizer=SPSAOptimizer,
    optim_hyperparams={'a': 0.2, 'c': 0.06, 'A': 0.01 * EPOCHS},
    evaluate_functions={'acc': acc},
    evaluate_on_train=True,
    verbose='text',
    # Use the notebook-level SEED constant rather than a duplicated literal,
    # so changing the seed in the configuration cell actually takes effect.
    seed=SEED,
)
[10]:
from lambeq import Dataset
# Training data is served in batches of BATCH_SIZE.
train_dataset = Dataset(
train_circuits,
train_labels,
batch_size=BATCH_SIZE)
# The validation data is not shuffled; no batch size is given here, so
# Dataset's default applies.
val_dataset = Dataset(dev_circuits, dev_labels, shuffle=False)
Train
[11]:
# Run training with the dev set as validation data; with log_interval=12 the
# output below shows one progress line every 12 epochs.
trainer.fit(train_dataset, val_dataset, log_interval=12)
Epoch 12: train/loss: 0.6305 valid/loss: 0.7275 train/acc: 0.6143 valid/acc: 0.4667
Epoch 24: train/loss: 0.8510 valid/loss: 0.6093 train/acc: 0.6571 valid/acc: 0.6667
Epoch 36: train/loss: 0.6167 valid/loss: 0.7777 train/acc: 0.7000 valid/acc: 0.6333
Epoch 48: train/loss: 0.2073 valid/loss: 0.7530 train/acc: 0.6857 valid/acc: 0.6333
Epoch 60: train/loss: 0.2046 valid/loss: 0.6212 train/acc: 0.8143 valid/acc: 0.6000
Epoch 72: train/loss: 0.1950 valid/loss: 0.7029 train/acc: 0.9429 valid/acc: 0.8000
Epoch 84: train/loss: 0.3449 valid/loss: 0.5687 train/acc: 0.9571 valid/acc: 0.7667
Epoch 96: train/loss: 0.2179 valid/loss: 0.6687 train/acc: 0.9429 valid/acc: 0.7667
Epoch 108: train/loss: 0.2575 valid/loss: 0.6051 train/acc: 0.9429 valid/acc: 0.7667
Epoch 120: train/loss: 0.1563 valid/loss: 0.5788 train/acc: 0.9714 valid/acc: 0.8000
Training completed!
Show results
[12]:
import matplotlib.pyplot as plt
import numpy as np

# 2x2 grid: top row shows loss, bottom row shows accuracy; the left column is
# the training set and the right column is the development set.
fig, ((ax_tl, ax_tr), (ax_bl, ax_br)) = plt.subplots(
    2, 2, sharex=True, sharey='row', figsize=(10, 6))

for axis, title in ((ax_tl, 'Training set'), (ax_tr, 'Development set')):
    axis.set_title(title)
for axis in (ax_bl, ax_br):
    axis.set_xlabel('Iterations')
ax_bl.set_ylabel('Accuracy')
ax_tl.set_ylabel('Loss')

colours = iter(plt.rcParams['axes.prop_cycle'].by_key()['color'])
range_ = np.arange(1, trainer.epochs + 1)

# Each curve takes the next colour of the default cycle, in this order.
curves = (
    (ax_tl, trainer.train_epoch_costs),
    (ax_bl, trainer.train_eval_results['acc']),
    (ax_tr, trainer.val_costs),
    (ax_br, trainer.val_eval_results['acc']),
)
for axis, series in curves:
    axis.plot(range_, series, color=next(colours))

# Final evaluation on the held-out test set.
test_acc = acc(model(test_circuits), np.array(test_labels))
print('Test accuracy:', test_acc)
Test accuracy: 0.6666667