Source code for lambeq.backend.converters.tk

# Copyright 2021-2024 Cambridge Quantum Computing Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Interface with tket
===================
Module containing the functions to convert from and to tket. This work is
based on DisCoPy (https://discopy.org/) which is released under the
BSD 3-Clause "New" or "Revised" License.

"""
from __future__ import annotations

from typing import cast

import numpy as np
import pytket as tk
from pytket.circuit import (Bit, Command, Op, OpType, Qubit)
from pytket.utils import probs_from_counts
from typing_extensions import Self

from lambeq.backend import Functor
from lambeq.backend.quantum import (bit, Box, Bra, CCX, CCZ, Controlled, CRx,
                                    CRy, CRz, Daggered, Diagram, Discard,
                                    GATES, Id, Ket, Measure, quantum, qubit,
                                    Rx, Ry, Rz, Scalar, Swap, X, Y, Z)

# Lookup table from lambeq gate names to the matching pytket ``OpType``.
# Every gate except ``Swap`` is spelled identically on both sides, so those
# entries are looked up by name; ``Swap`` maps to pytket's ``SWAP``.
OPTYPE_MAP = {
    **{gate_name: getattr(OpType, gate_name)
       for gate_name in ('H', 'X', 'Y', 'Z', 'S', 'T',
                         'Rx', 'Ry', 'Rz',
                         'CX', 'CZ',
                         'CRx', 'CRy', 'CRz',
                         'CCX')},
    'Swap': OpType.SWAP,
}


class Circuit(tk.Circuit):
    """Extend pytket.Circuit with counts post-processing.

    On top of the pytket circuit, every instance carries three extra
    attributes:

    * ``post_selection`` -- dict mapping a bit index to the value that
      bit is post-selected on,
    * ``scalar`` -- number the counts are multiplied by in
      :py:meth:`get_counts`,
    * ``post_processing`` -- lambeq diagram of classical post-processing.

    """

    @staticmethod
    def upgrade(tk_circuit: tk.Circuit) -> Circuit:
        """Takes a :py:class:`pytket.Circuit`, returns a :py:class:`Circuit`.

        The post-selection, scalar and post-processing of `tk_circuit`
        are carried over when it has them; otherwise the defaults of
        :py:meth:`__init__` apply. Every gate of `tk_circuit` is then
        replayed on the new circuit.

        """
        # A plain pytket circuit has none of the three extra attributes,
        # so read them defensively instead of assuming they exist.
        post_selection = getattr(tk_circuit, 'post_selection', None)
        result = Circuit(
            tk_circuit.n_qubits,
            len(tk_circuit.bits),
            # Copy, so that later updates of the new circuit's
            # post-selection do not leak back into `tk_circuit`.
            dict(post_selection) if post_selection else None,
            getattr(tk_circuit, 'scalar', None),
            getattr(tk_circuit, 'post_processing', None))
        for gate in tk_circuit:
            name = gate.op.type.name
            inputs = gate.op.params + [
                x.index[0] for x in gate.qubits + gate.bits]
            # The gate is re-applied through the circuit method that has
            # the same name as the op type.
            getattr(result, name)(*inputs)
        return result

    def __init__(self,
                 n_qubits: int = 0,
                 n_bits: int = 0,
                 post_selection: dict[int, int] | None = None,
                 scalar: float | None = None,
                 post_processing: Diagram | None = None) -> None:
        """Initialise a circuit.

        Parameters
        ----------
        n_qubits : int, default: 0
            Number of qubits, forwarded to pytket.
        n_bits : int, default: 0
            Number of bits, forwarded to pytket.
        post_selection : dict of int to int, optional
            Bit index to post-selected value. Defaults to an empty dict.
        scalar : float, optional
            Scaling factor. Defaults to 1.
        post_processing : Diagram, optional
            Classical post-processing. Defaults to the identity on the
            bits that are not post-selected.

        """
        self.post_selection = post_selection or {}
        # Explicit `None` check: `scalar or 1` would silently replace a
        # legitimate scalar of 0 with 1.
        self.scalar = 1 if scalar is None else scalar
        self.post_processing = (
            post_processing
            or Id(bit ** (n_bits - len(self.post_selection))))
        super().__init__(n_qubits, n_bits)

    def __repr__(self) -> str:
        def repr_gate(gate) -> str:
            name, inputs = gate.op.type.name, gate.op.params + [
                x.index[0] for x in gate.qubits + gate.bits]
            return f'{name}({", ".join(map(str, inputs))})'

        str_bits = f', {len(self.bits)}' if self.bits else ''
        init = [f'tk.Circuit({self.n_qubits}{str_bits})']
        gates = list(map(repr_gate, list(self)))
        post_select = ([f'post_select({self.post_selection})']
                       if self.post_selection else [])
        scalar = [f'scale({x:.3g})' for x in [self.scalar] if x != 1]
        post_process = [f'post_process({repr(d)})'
                        for d in [self.post_processing] if d]
        # The pieces are chained with dots, e.g.
        # ``tk.Circuit(1).H(0).scale(0.5)``.
        return '.'.join(init + gates + post_select + scalar + post_process)

    def __getstate__(self):
        # Store the instance attributes next to the state produced by
        # the parent class so that they survive pickling.
        state = super().__getstate__()
        state[0].update(self.__dict__)
        return state

    def __setstate__(self, state) -> None:
        # Take the extra attributes back out before handing the rest of
        # the state to the parent class.
        for attr in ['scalar', 'post_selection', 'post_processing']:
            setattr(self, attr, state[0].pop(attr))
        super().__setstate__(state)

    @property
    def n_bits(self) -> int:
        """Number of bits in a circuit."""
        return len(self.bits)

    def add_bit(self, unit, offset=None) -> None:
        """Add a bit, update post_processing.

        When `offset` is given, an identity wire is appended to the
        post-processing and then swapped past the wires from position
        `offset` onwards. With `offset=None` the post-processing is left
        untouched.

        """
        if offset is not None:
            self.post_processing @= Id(bit)
            self.post_processing >>= (
                Id(bit ** offset)
                @ Swap(self.post_processing.cod[offset:-1], bit))
        super().add_bit(unit)

    def rename_units(self, renaming):
        """Rename units in a circuit.

        Keys of ``post_selection`` that refer to a renamed bit are moved
        to the index of the new bit before the renaming is applied to
        the circuit itself.

        """
        bits_to_rename = [
            old for old in renaming.keys()
            if isinstance(old, Bit)
            and old.index[0] in self.post_selection]
        post_selection_renaming = {
            renaming[old].index[0]: self.post_selection[old.index[0]]
            for old in bits_to_rename}
        # Drop all old keys first and only then add the new ones, so a
        # new index that equals another old index is not deleted again.
        for old in bits_to_rename:
            del self.post_selection[old.index[0]]
        self.post_selection.update(post_selection_renaming)
        super().rename_units(renaming)

    def scale(self, number: float) -> Self:
        """Scale a circuit by a given number."""
        self.scalar *= number
        return self

    def post_select(self, post_selection: dict[int, int]) -> Self:
        """Post-select bits on a given value."""
        self.post_selection.update(post_selection)
        return self

    def post_process(self, process: Diagram) -> Self:
        """Classical post-processing."""
        self.post_processing >>= process
        return self

    def get_counts(self,
                   *others: Circuit,
                   backend=None,
                   **params) -> list[np.ndarray]:
        """Runs a circuit on a backend and returns the counts.

        Parameters
        ----------
        *others : Circuit
            Further circuits to process together with this one.
        backend
            Object providing ``process_circuits`` and ``get_result``;
            it has to be supplied by the caller.
        **params
            Optional settings: ``n_shots`` (default ``2**10``),
            ``scale`` (default True), ``post_select`` (default True),
            ``compilation`` (default None), ``normalize``
            (default True), ``measure_all`` (default False) and
            ``seed`` (default None).

        Returns
        -------
        list
            One entry per circuit, in the order ``(self, *others)``.

        Notes
        -----
        With ``measure_all`` or ``compilation`` set, the circuits are
        modified in place.

        """
        n_shots = params.get('n_shots', 2**10)
        scale = params.get('scale', True)
        post_select = params.get('post_select', True)
        compilation = params.get('compilation', None)
        normalize = params.get('normalize', True)
        measure_all = params.get('measure_all', False)
        seed = params.get('seed', None)
        circuits = (self, ) + others
        if measure_all:
            for circuit in circuits:
                circuit.measure_all()
        if compilation is not None:
            for circuit in circuits:
                compilation.apply(circuit)
        handles = backend.process_circuits(
            circuits, n_shots=n_shots, seed=seed)
        counts = [backend.get_result(h).get_counts() for h in handles]
        if normalize:
            counts = list(map(probs_from_counts, counts))
        if post_select:
            for i, circuit in enumerate(circuits):
                post_selected = dict()
                for bitstring, count in counts[i].items():
                    # Keep only the outcomes that agree with every
                    # post-selected bit ...
                    if all(bitstring[index] == value
                           for index, value
                           in circuit.post_selection.items()):
                        # ... and remove those bits from the key.
                        key = tuple(
                            value for index, value in enumerate(bitstring)
                            if index not in circuit.post_selection)
                        post_selected.update({key: count})
                counts[i] = post_selected
        if scale:
            for i, circuit in enumerate(circuits):
                for bitstring in counts[i]:
                    counts[i][bitstring] *= circuit.scalar
        return counts
def to_tk(circuit: Diagram) -> Circuit:
    """
    Takes a :py:class:`lambeq.quantum.Diagram`, returns
    a :py:class:`Circuit`.

    The diagram is traversed layer by layer. Kets add qubits, Measure
    and Bra boxes add bits and measurements (a Bra also post-selects),
    Discard drops qubits from the bookkeeping, swaps rename units or
    extend the post-processing, scalars scale the circuit and every
    other box is added as a gate.

    Raises
    ------
    NotImplementedError
        If a box has no known pytket counterpart.

    """
    # bits and qubits are lists of register indices, at layer i we want
    # len(bits) == circuit[:i].cod.count(bit) and same for qubits
    tk_circ = Circuit()
    bits: list[int] = []
    qubits: list[int] = []
    circuit = circuit.init_and_discard()

    def remove_ket1(_, box: Box) -> Diagram | Box:
        # Replace Ket(1) by Ket(0) followed by X, leave the rest alone.
        ob_map: dict[Box, Diagram]
        ob_map = {Ket(1): Ket(0) >> X}  # type: ignore[dict-item]
        return ob_map.get(box, box)

    def prepare_qubits(qubits: list[int],
                       box: Box,
                       offset: int) -> list[int]:
        # Insert len(box.cod) fresh qubits at wire position `offset`,
        # shifting the register index of every qubit from `start` on.
        start = (tk_circ.n_qubits if not qubits else 0
                 if not offset else qubits[offset - 1] + 1)
        renaming = {Qubit('q', i): Qubit('q', i + len(box.cod))
                    for i in range(start, tk_circ.n_qubits)}
        tk_circ.rename_units(renaming)
        tk_circ.add_blank_wires(len(box.cod))
        return (qubits[:offset]
                + list(range(start, start + len(box.cod)))
                + [i + len(box.cod) for i in qubits[offset:]])

    def measure_qubits(qubits: list[int],
                       bits: list[int],
                       box: Box,
                       bit_offset: int,
                       qubit_offset: int) -> tuple[list[int], list[int]]:
        if isinstance(box, Bra):
            # The bit added below is post-selected on the Bra's value.
            tk_circ.post_select({len(tk_circ.bits): box.bit})
        for j, _ in enumerate(box.dom):
            i_bit, i_qubit = len(tk_circ.bits), qubits[qubit_offset + j]
            # Only a Measure leaves a classical wire behind, so only
            # then does the post-processing have to grow.
            offset = len(bits) if isinstance(box, Measure) else None
            tk_circ.add_bit(Bit(i_bit), offset=offset)
            tk_circ.Measure(i_qubit, i_bit)
            if isinstance(box, Measure):
                bits = (bits[:bit_offset + j] + [i_bit]
                        + bits[bit_offset + j:])
        # remove measured qubits
        qubits = (qubits[:qubit_offset]
                  + qubits[qubit_offset + len(box.dom):])
        return bits, qubits

    def swap(i: int, j: int, unit_factory=Qubit) -> None:
        # Exchange two units by renaming through a temporary unit.
        old, tmp, new = (
            unit_factory(i), unit_factory('tmp', 0), unit_factory(j))
        tk_circ.rename_units({old: tmp})
        tk_circ.rename_units({new: old})
        tk_circ.rename_units({tmp: new})

    def add_gate(qubits: list[int], box: Box, offset: int) -> None:
        is_dagger = False
        if isinstance(box, Daggered):
            # Translate the undaggered box, dagger the pytket op below.
            box = box.dagger()
            is_dagger = True

        i_qubits = [qubits[offset + j] for j in range(len(box.dom))]

        if isinstance(box, (Rx, Ry, Rz)):
            op = Op.create(OPTYPE_MAP[box.name[:2]], 2 * box.phase)
        elif isinstance(box, Controlled):
            # The following works only for controls on single qubit gates

            # reverse the distance order
            dists = []
            curr_box: Box | Controlled = box
            while isinstance(curr_box, Controlled):
                dists.append(curr_box.distance)
                curr_box = curr_box.controlled
            dists.reverse()

            # Index of the controlled qubit is the last entry in rel_idx
            rel_idx = [0]
            for dist in dists:
                if dist > 0:
                    # Add control to the left, offset by distance
                    rel_idx = [0] + [i + dist for i in rel_idx]
                else:
                    # Add control to the right, don't offset
                    right_most_idx = max(rel_idx)
                    rel_idx.insert(-1, right_most_idx - dist)

            i_qubits = [i_qubits[i] for i in rel_idx]

            # Parametrised boxes are named like 'CRz(0.5)'; strip the
            # argument list to obtain the key of OPTYPE_MAP.
            name = box.name.split('(')[0]
            if name in ('CX', 'CZ', 'CCX'):
                op = Op.create(OPTYPE_MAP[name])
            elif name in ('CRx', 'CRy', 'CRz'):
                op = Op.create(OPTYPE_MAP[name], 2 * box.phase)
            else:
                # Fail explicitly instead of reaching the code below
                # with `op` unbound.
                raise NotImplementedError(box)
        elif box.name in OPTYPE_MAP:
            op = Op.create(OPTYPE_MAP[box.name])
        else:
            raise NotImplementedError(box)

        if is_dagger:
            op = op.dagger

        tk_circ.add_gate(op, i_qubits)

    circuit = Functor(target_category=quantum,  # type: ignore [assignment]
                      ob=lambda _, x: x,
                      ar=remove_ket1)(circuit)  # type: ignore [arg-type]
    for left, box, _ in circuit:
        if isinstance(box, Ket):
            qubits = prepare_qubits(qubits, box, left.count(qubit))
        elif isinstance(box, (Measure, Bra)):
            bits, qubits = measure_qubits(
                qubits, bits, box, left.count(bit), left.count(qubit))
        elif isinstance(box, Discard):
            qubits = (qubits[:left.count(qubit)]
                      + qubits[left.count(qubit) + box.dom.count(qubit):])
        elif isinstance(box, Swap):
            if box == Swap(qubit, qubit):
                off = left.count(qubit)
                swap(qubits[off], qubits[off + 1])
            elif box == Swap(bit, bit):
                off = left.count(bit)
                if tk_circ.post_processing:
                    # Swap the wires inside the post-processing rather
                    # than renaming the bits.
                    right = Id(tk_circ.post_processing.cod[off + 2:])
                    tk_circ.post_process(
                        Id(bit ** off) @ Swap(bit, bit) @ right)
                else:
                    swap(bits[off], bits[off + 1], unit_factory=Bit)
            else:  # pragma: no cover
                continue  # bits and qubits live in different registers.
        elif isinstance(box, Scalar):
            tk_circ.scale(abs(box.array) ** 2)
        elif isinstance(box, Box):
            add_gate(qubits, box, left.count(qubit))
        else:  # pragma: no cover
            raise NotImplementedError
    return tk_circ
def from_tk(tk_circuit: tk.Circuit) -> Diagram:
    """Translates from tket to a lambeq Diagram.

    The result starts with one ``Ket(0)`` per qubit, applies every
    command of the circuit in order, then closes each wire:
    post-selected qubits with a ``Bra``, the other qubits with a
    ``Discard``, while bits stay open. A scalar different from 1 is
    tensored on as ``Scalar(sqrt(abs(scalar)))`` and the circuit's
    post-processing is composed last.

    Raises
    ------
    NotImplementedError
        If a gate has no lambeq counterpart here.

    """
    upgraded: Circuit = Circuit.upgrade(tk_circuit)
    n_qubits = upgraded.n_qubits

    def box_and_offset_from_tk(tk_gate) -> tuple[Diagram, int]:
        # Returns the diagram for one tket command together with the
        # index of the left-most wire it acts on.
        name: str = tk_gate.op.type.name
        args = tk_gate.args
        offset = args[0].index[0]

        if name.endswith('dg'):
            # Translate the undaggered gate, then dagger the result.
            plain_gate = Command(tk_gate.op.dagger, args)
            plain_box, offset = box_and_offset_from_tk(plain_gate)
            return plain_box.dagger().to_diagram(), offset

        box: Box | Diagram | None = None
        arity = len(args)
        if arity == 1:  # single qubit gate
            rotations = {'Rx': Rx, 'Ry': Ry, 'Rz': Rz}
            if name in rotations:
                # tket angles are twice the lambeq phase
                box = rotations[name](tk_gate.op.params[0] / 2)
            elif name in GATES:
                box = cast(Box, GATES[name])
        elif arity == 2:  # two qubit gate
            first, second = args[0].index[0], args[1].index[0]
            distance = second - first
            # With a negative distance the second wire is the left one.
            offset = second if distance < 0 else first

            controlled_rotations = {'CRx': CRx, 'CRy': CRy, 'CRz': CRz}
            controlled_paulis = {'CX': X, 'CY': Y, 'CZ': Z}
            if name in controlled_rotations:
                box = controlled_rotations[name](
                    tk_gate.op.params[0] / 2, distance)
            elif name == 'SWAP':
                # Permutation exchanging the two outermost wires.
                width = abs(distance) + 1
                order = list(range(width))
                order[0], order[-1] = order[-1], order[0]
                box = Diagram.permutation(qubit ** width, order)
            elif name in controlled_paulis:
                box = Controlled(controlled_paulis[name], distance)
        elif arity == 3:  # three qubit gate
            controls = (args[0].index[0], args[1].index[0])
            target = args[2].index[0]
            wires = controls + (target,)
            lowest = min(wires)
            span = max(wires) - lowest + 1

            doubly_controlled = {'CCX': CCX, 'CCZ': CCZ}
            if name in doubly_controlled:
                box = Id(qubit**span).apply_gate(
                    doubly_controlled[name], *controls, target)
            offset = lowest

        if box is None:
            raise NotImplementedError
        return box.to_diagram(), offset  # type: ignore [return-value]

    diagram = Ket(*(0, ) * n_qubits).to_diagram()
    bras = {}
    for command in upgraded.get_commands():
        if command.op.type.name == 'Measure':
            wire: int = command.qubits[0].index[0]
            bit_index: int = command.bits[0].index[0]
            if bit_index in upgraded.post_selection:
                bras[wire] = upgraded.post_selection[bit_index]
                continue  # post selection happens at the end
            gate, width = Measure(), 1
        else:
            gate, wire = box_and_offset_from_tk(command)
            width = len(gate.dom)
        # Pad the gate with the untouched wires on either side.
        codomain = diagram.cod
        layer = codomain[:wire] @ gate @ codomain[wire + width:]
        diagram = diagram >> layer

    def close_wire(index, wire_type):
        # Box that terminates the wire at position `index`.
        if index in bras:
            return Bra(bras[index])
        if wire_type == qubit:
            return Discard()
        return Id(bit)

    closing = [close_wire(i, x) for i, x in enumerate(diagram.cod)]
    diagram = diagram >> Id().tensor(*closing)  # type: ignore[arg-type]
    if upgraded.scalar != 1:
        diagram = diagram @ Scalar(np.sqrt(abs(upgraded.scalar)))
    return diagram >> upgraded.post_processing  # type: ignore [return-value]