Coverage for /home/runner/work/tket/tket/pytket/pytket/zx/tensor_eval.py: 97%
211 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-09 15:08 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-09 15:08 +0000
1# Copyright Quantinuum
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
"""Collection of methods to evaluate a ZXDiagram to a tensor.

Each generator is converted to a dense numpy array. The arrays are assembled
into a quimb tensor network, which is simplified and then contracted along a
greedy contraction path. quimb is an optional dependency, installed with
``pip install 'pytket[ZX]'``."""
18import warnings
19from math import cos, floor, pi, sin, sqrt
20from typing import Any
22import numpy as np
23import sympy
25from pytket.zx import (
26 CliffordGen,
27 DirectedGen,
28 PhasedGen,
29 QuantumType,
30 Rewrite,
31 ZXBox,
32 ZXDiagram,
33 ZXGen,
34 ZXType,
35 ZXVert,
36)
38try:
39 import quimb.tensor as qtn # type: ignore
40except ModuleNotFoundError:
41 warnings.warn(
42 'Missing package for tensor evaluation of ZX diagrams. Run "pip '
43 "install 'pytket[ZX]'\" to install the optional dependencies.",
44 stacklevel=2,
45 )
def _gen_to_tensor(gen: ZXGen, rank: int) -> np.ndarray:
    """Convert a single generator into a dense tensor.

    Dispatches on the generator's class: phased generators and Clifford
    generators are built with ``rank`` legs. Directed generators and boxes
    determine their own shape, so ``rank`` is not used for them.

    NOTE(review): a ZXBox's inner diagram is evaluated as-is, with no
    ``Rewrite.basic_wires()`` and no scalar adjustment applied here; confirm
    that whatever was applied to the outer diagram also covers box contents.

    :param gen: Generator to convert.
    :param rank: Number of legs the generator has in the enclosing diagram.
    :raises ValueError: if the generator's class has no tensor conversion.
    :return: Complex numpy array describing the generator.
    """
    if isinstance(gen, PhasedGen):
        tensor = _spider_to_tensor(gen, rank)
    elif isinstance(gen, CliffordGen):
        tensor = _clifford_to_tensor(gen, rank)
    elif isinstance(gen, DirectedGen):
        tensor = _dir_gen_to_tensor(gen)
    elif isinstance(gen, ZXBox):
        tensor = _tensor_from_basic_diagram(gen.diagram)
    else:
        raise ValueError(f"Cannot convert generator of type {gen.type} to a tensor")
    return tensor
def _spider_to_tensor(gen: PhasedGen, rank: int) -> np.ndarray:
    """Build the dense tensor of a phased generator with ``rank`` legs.

    Spider-like phases are read in half-turns: ``param`` is reduced modulo 2
    and multiplied by pi before exponentiation. For an H-box, ``param`` is
    instead used directly as the complex value of the all-ones entry.

    The Z-like tensors (Z spider, XY, XZ, YZ) are non-zero only on the
    all-zeros and all-ones entries. When ``rank == 0`` those two entries are
    the same element. The all-ones contribution is therefore added rather
    than assigned, so a rank-0 Z spider evaluates to ``1 + e^{i*pi*param}``.
    This matches the rank-0 X spider produced by the parity formula below.

    :param gen: Phased generator (Z/X spider, H-box, or XY/XZ/YZ vertex).
    :param rank: Number of legs; the result has shape ``(2,) * rank``.
    :raises ValueError: if ``param`` is symbolic, or the generator type has no
        tensor conversion here.
    :return: Complex array of shape ``(2,) * rank``.
    """
    try:
        if gen.type == ZXType.Hbox:
            param_c = complex(gen.param)
        else:
            param = float(gen.param)
    except TypeError as e:
        # If parameter is symbolic, we cannot evaluate the tensor
        raise ValueError(
            f"Evaluation of ZXDiagram failed due to symbolic expression {gen.param}"
        ) from e
    size = pow(2, rank)
    # Flat index of the all-ones entry; it coincides with the all-zeros entry
    # (index 0) when rank == 0, hence the "+=" updates below.
    last = size - 1
    if gen.type == ZXType.ZSpider:
        x = param / 2.0
        modval = 2.0 * (x - floor(x))  # param reduced into [0, 2)
        phase = np.exp(1j * modval * pi)
        t = np.zeros(size, dtype=complex)
        t[0] = 1.0
        t[last] += phase
    elif gen.type == ZXType.XSpider:
        x = param / 2.0
        modval = 2.0 * (x - floor(x))
        phase = np.exp(1j * modval * pi)
        t = np.full(size, 1.0, dtype=complex)
        constant = pow(sqrt(0.5), rank)
        for i in range(size):
            # Entry is (1 + phase) or (1 - phase), chosen by the parity of the
            # index bits, scaled by (1/sqrt(2))^rank.
            parity = i.bit_count()
            t[i] += phase if parity % 2 == 0 else -phase
            t[i] *= constant
    elif gen.type == ZXType.Hbox:
        # Every entry is 1 except the all-ones one. For rank == 0 the single
        # entry is the (empty) all-ones index, so plain assignment is correct.
        t = np.full(size, 1.0, dtype=complex)
        t[last] = param_c
    elif gen.type == ZXType.XY:
        x = param / 2.0
        modval = 2.0 * (x - floor(x))
        phase = np.exp(-1j * modval * pi)
        t = np.zeros(size, dtype=complex)
        t[0] = sqrt(0.5)
        t[last] += sqrt(0.5) * phase
    elif gen.type == ZXType.XZ:
        x = param / 2.0
        modval = x - floor(x)  # param / 2 reduced into [0, 1)
        t = np.zeros(size, dtype=complex)
        t[0] = cos(modval * pi)
        t[last] += sin(modval * pi)
    elif gen.type == ZXType.YZ:
        x = param / 2.0
        modval = x - floor(x)
        t = np.zeros(size, dtype=complex)
        t[0] = cos(modval * pi)
        t[last] += -1j * sin(modval * pi)
    else:
        raise ValueError(
            f"Cannot convert phased generator of type {gen.type} to a tensor"
        )
    return t.reshape(tuple([2] * rank))
def _clifford_to_tensor(gen: CliffordGen, rank: int) -> np.ndarray:
    """Build the dense tensor of a Pauli-measured (PX/PY/PZ) vertex.

    Only the all-zeros and all-ones entries can be non-zero. The truthiness of
    ``gen.param`` selects the variant:

    - PX: all-zeros entry 1/sqrt(2); all-ones entry -1/sqrt(2) if param is
      truthy, else +1/sqrt(2).
    - PY: all-zeros entry 1/sqrt(2); all-ones entry +i/sqrt(2) if param is
      truthy, else -i/sqrt(2).
    - PZ: a single 1 on the all-ones entry if param is truthy, else on the
      all-zeros entry.

    For PX/PY the all-ones contribution is added rather than assigned. When
    ``rank == 0`` both entries are the same element, and both contributions
    must survive (e.g. ``sqrt(2)`` for PX with a falsy param). This matches
    the XY-vertex tensor these vertices specialise.

    :param gen: Clifford generator of type PX, PY or PZ.
    :param rank: Number of legs; the result has shape ``(2,) * rank``.
    :raises ValueError: if the generator type is not PX, PY or PZ.
    :return: Complex array of shape ``(2,) * rank``.
    """
    size = pow(2, rank)
    # Flat index of the all-ones entry (equals 0 when rank == 0).
    last = size - 1
    t = np.zeros(size, dtype=complex)
    if gen.type == ZXType.PX:
        t[0] = sqrt(0.5)
        t[last] += -sqrt(0.5) if gen.param else sqrt(0.5)
    elif gen.type == ZXType.PY:
        t[0] = sqrt(0.5)
        t[last] += 1j * sqrt(0.5) if gen.param else -1j * sqrt(0.5)
    elif gen.type == ZXType.PZ:
        # Exactly one entry is set, so rank 0 needs no special handling.
        if gen.param:
            t[last] = 1.0
        else:
            t[0] = 1.0
    else:
        raise ValueError(
            f"Cannot convert Clifford generator of type {gen.type} to a tensor"
        )
    return t.reshape(tuple([2] * rank))
def _dir_gen_to_tensor(gen: DirectedGen) -> np.ndarray:
    """Build the 2x2 tensor of a directed generator.

    Only the Triangle generator is supported. Its tensor is all ones except
    for a zero at index ``[1, 0]``.

    :param gen: Directed generator to convert.
    :raises ValueError: if the generator is not a Triangle.
    :return: Complex 2x2 numpy array.
    """
    if gen.type != ZXType.Triangle:
        raise ValueError(
            f"Cannot convert directed generator of type {gen.type} to a tensor"
        )
    return np.array([[1.0, 1.0], [0.0, 1.0]], dtype=complex)
# Complex 2x2 identity; inserted on a wire that joins two boundary vertices
# directly, so that each boundary gets its own output index.
_id_tensor = np.eye(2, dtype=complex)

# Vertex types that make up a diagram's boundary; they carry no tensor of
# their own during evaluation.
_boundary_types = [ZXType.Input, ZXType.Output, ZXType.Open]
def _tensor_from_basic_diagram(diag: ZXDiagram) -> np.ndarray:
    """Contract ``diag`` into a dense tensor with one index per boundary vertex.

    How the network is built:

    - Every wire becomes one tensor-network index.
    - Every non-boundary vertex becomes one tensor whose legs are its adjacent
      wires; a self-loop wire contributes its index twice.
    - Wire types are never consulted, so every wire is treated as a plain
      connection. The public entry points in this module apply
      ``Rewrite.basic_wires()`` before calling this.
    - Output indices follow the order of ``diag.get_boundary()``.
    - A wire joining two boundary vertices directly is split with an identity
      tensor, so that each boundary keeps a distinct output index.

    The network is simplified, contracted greedily, and multiplied by the
    diagram scalar.

    :param diag: Diagram to evaluate; its scalar must be numeric.
    :raises ValueError: if the scalar, or any generator parameter, is symbolic.
    :return: The contracted tensor (presumably a 0-d array when there are no
        boundaries).
    """
    try:
        scalar = complex(diag.scalar)
    except TypeError as e:
        raise ValueError(
            f"Cannot evaluate a diagram with a symbolic scalar. Given scalar: "
            f"{diag.scalar}"
        ) from e
    wires = diag.wires
    wire_index = {wire: pos for pos, wire in enumerate(wires)}
    fresh_index = len(wires)
    tensors: list[Any] = []
    bridged_wires = set()
    output_inds = []

    # Boundaries go first so that output_inds follows get_boundary() order.
    for boundary in diag.get_boundary():
        wire = diag.adj_wires(boundary)[0]
        wire_pos = wire_index[wire]
        far_end = diag.other_end(wire, boundary)
        if diag.get_zxtype(far_end) not in _boundary_types or wire in bridged_wires:
            output_inds.append(wire_pos)
            continue
        # First visit of a boundary-to-boundary wire: route this boundary
        # through an identity tensor onto a fresh index.
        tensors.append(qtn.Tensor(data=_id_tensor, inds=[wire_pos, fresh_index]))
        output_inds.append(fresh_index)
        fresh_index += 1
        bridged_wires.add(wire)

    for vertex in diag.vertices:
        gen = diag.get_vertex_ZXGen(vertex)
        if gen.type in _boundary_types:
            continue  # already handled in the boundary pass
        legs = []
        for wire in diag.adj_wires(vertex):
            wire_pos = wire_index[wire]
            legs.append(wire_pos)
            if diag.other_end(wire, vertex) == vertex:
                # Self-loop: both ends attach to this same tensor.
                legs.append(wire_pos)
        tensors.append(qtn.Tensor(data=_gen_to_tensor(gen, len(legs)), inds=legs))

    network = qtn.TensorNetwork(tensors)
    network.full_simplify_(seq="ADCR")
    contracted = network.contract(output_inds=output_inds, optimize="greedy")
    # A full contraction returns a bare scalar rather than a Tensor.
    dense = (
        contracted.data
        if isinstance(contracted, qtn.Tensor)
        else np.asarray(contracted)
    )
    return dense * scalar
def tensor_from_quantum_diagram(diag: ZXDiagram) -> np.ndarray:
    """
    Evaluates a purely quantum :py:class:`ZXDiagram` as a tensor. Indices of
    the resulting tensor match the order of the boundary vertices from
    :py:meth:`ZXDiagram.get_boundary`.

    The input diagram is not modified: evaluation runs on a copy.
    ``Rewrite.basic_wires()`` is applied to that copy, and its scalar is
    replaced by the square root of ``diag.scalar``.

    Throws an exception if the diagram contains any non-quantum vertex or wire,
    or if it contains any symbolic parameters.
    """
    if any(diag.get_qtype(v) != QuantumType.Quantum for v in diag.vertices):
        raise ValueError(
            "Non-quantum vertex found. tensor_from_quantum_diagram only "
            "supports diagrams consisting of only quantum components"
        )
    if any(diag.get_wire_qtype(w) != QuantumType.Quantum for w in diag.wires):
        raise ValueError(
            "Non-quantum wire found. tensor_from_quantum_diagram only "
            "supports diagrams consisting of only quantum components"
        )
    working = ZXDiagram(diag)
    # Multiplying by 1/sqrt(s) leaves sqrt(s) as the copy's scalar. Presumably
    # a quantum diagram stores its doubled-picture scalar - TODO confirm.
    # NOTE(review): a zero scalar gives 1/sqrt(0) (sympy ``zoo``) here.
    working.multiply_scalar(1 / sympy.sqrt(diag.scalar))
    Rewrite.basic_wires().apply(working)
    return _tensor_from_basic_diagram(working)
def tensor_from_mixed_diagram(diag: ZXDiagram) -> np.ndarray:
    """
    Evaluates an arbitrary :py:class:`ZXDiagram` as a tensor in the doubled
    picture - that is, each quantum generator is treated as a pair of conjugate
    generators, whereas a classical generator is just itself.

    The indices of the resulting tensor match the order of the boundary
    vertices from :py:meth:`ZXDiagram.get_boundary`, with quantum boundaries
    split into two. For example, if the boundary is ``[qb1, cb1, qb2]``, the
    indices will match ``[qb1, qb1_conj, cb1, qb2, qb2_conj]``.

    Evaluation runs on the doubled copy returned by
    ``diag.to_doubled_diagram()``, after ``Rewrite.basic_wires()`` is applied
    to it.

    Throws an exception if the diagram contains any symbolic parameters.
    """
    doubled = diag.to_doubled_diagram()
    rewrite = Rewrite.basic_wires()
    rewrite.apply(doubled)
    return _tensor_from_basic_diagram(doubled)
def _format_tensor_as_unitary(diag: ZXDiagram, tensor: np.ndarray) -> np.ndarray:
    """Reshape a boundary-indexed tensor into an outputs-by-inputs matrix.

    Boundary vertices of type Input form the column index. Every other
    boundary vertex (Output or Open) forms the row index. Within each group,
    the relative order from ``diag.get_boundary()`` is kept, and earlier
    boundaries are more significant.

    :param diag: Diagram whose boundary describes the tensor's indices.
    :param tensor: Tensor with one size-2 axis per boundary vertex, in
        ``diag.get_boundary()`` order.
    :return: Matrix of shape ``(2**n_out, 2**n_in)``.
    """
    in_ind: list[int] = []
    out_ind: list[int] = []
    for pos, vert in enumerate(diag.get_boundary()):
        bucket = in_ind if diag.get_zxtype(vert) == ZXType.Input else out_ind
        bucket.append(pos)
    in_dim = 2 ** len(in_ind)
    out_dim = 2 ** len(out_ind)
    grouped = np.transpose(tensor, in_ind + out_ind)
    return grouped.reshape(in_dim, out_dim).T
def unitary_from_quantum_diagram(diag: ZXDiagram) -> np.ndarray:
    """
    Evaluates a purely quantum :py:class:`ZXDiagram` as a matrix describing the
    linear map from inputs to outputs. Qubits are indexed according to ILO-BE
    convention based on relative position amongst inputs/outputs in
    :py:meth:`ZXDiagram.get_boundary`.

    The result has shape ``(2**n_out, 2**n_in)``. Boundary vertices that are
    not inputs (outputs and open boundaries) count towards ``n_out``.

    Throws an exception if the diagram contains any non-quantum vertex or wire,
    or if it contains any symbolic parameters.
    """
    tensor = tensor_from_quantum_diagram(diag)
    return _format_tensor_as_unitary(diag, tensor)
def unitary_from_classical_diagram(diag: ZXDiagram) -> np.ndarray:
    """
    Evaluates a purely classical :py:class:`ZXDiagram` as a matrix describing
    the linear map from inputs to outputs. Bits are indexed according to the
    ILO-BE convention based on relative position amongst inputs/outputs in
    :py:meth:`ZXDiagram.get_boundary`. Internally, each quantum generator is
    treated as a pair of conjugate generators (the doubled picture). Classical
    boundaries are not doubled, so each bit contributes a single index.

    Throws an exception if the diagram contains any non-classical boundary, or
    if it contains any symbolic parameters.
    """
    boundary = diag.get_boundary()
    if any(diag.get_qtype(b) != QuantumType.Classical for b in boundary):
        raise ValueError(
            "Non-classical boundary vertex found. "
            "unitary_from_classical_diagram only supports diagrams with "
            "only classical boundaries"
        )
    return _format_tensor_as_unitary(diag, tensor_from_mixed_diagram(diag))
def density_matrix_from_cptp_diagram(diag: ZXDiagram) -> np.ndarray:
    """
    Evaluates a :py:class:`ZXDiagram` with quantum boundaries but possibly
    mixed quantum and classical generators as a density matrix. Inputs are
    treated identically to outputs, i.e. the result is the Choi-state of the
    diagram. Qubits are indexed according to the ILO-BE convention based on the
    ordering of boundary vertices in :py:meth:`ZXDiagram.get_boundary`.

    The result is a square matrix of side ``2**n``, where ``n`` is the number
    of boundary vertices.

    Throws an exception if the diagram contains any non-quantum boundary, or if
    it contains any symbolic parameters.
    """
    boundary = diag.get_boundary()
    if any(diag.get_qtype(b) != QuantumType.Quantum for b in boundary):
        raise ValueError(
            "Non-quantum boundary vertex found. "
            "density_matrix_from_cptp_diagram only supports diagrams with "
            "only quantum boundaries"
        )
    tensor = tensor_from_mixed_diagram(diag)
    n_bounds = len(boundary)
    dim = 2**n_bounds
    # The doubled diagram alternates each original boundary index (even axes)
    # with its conjugate (odd axes). Gather all originals, then all conjugates.
    order = list(range(0, 2 * n_bounds, 2)) + list(range(1, 2 * n_bounds, 2))
    # NOTE(review): after the final transpose the conjugate-copy indices label
    # the rows; confirm this matches the intended rho vs rho^T convention.
    return np.transpose(tensor, order).reshape(dim, dim).T
def fix_boundaries_to_binary_states(
    diag: ZXDiagram, vals: dict[ZXVert, int]
) -> ZXDiagram:
    """
    Fixes (a subset of) the boundary vertices of a :py:class:`ZXDiagram` to
    computational basis states/post-selection.

    The input diagram is left untouched; a modified copy is returned. For each
    fixed boundary:

    - An X spider with phase 0 (for 0) or 1 half-turn (for 1) is attached to
      the boundary's neighbour, reusing the original wire's type and the
      neighbour's port.
    - The boundary vertex is removed.
    - The scalar is multiplied by 1/2 for quantum boundaries and by 1/sqrt(2)
      for classical ones. Presumably this cancels the normalisation of the
      one-legged X spider (sqrt(2) per copy in the tensor evaluation).

    :param diag: Diagram whose boundaries are to be fixed.
    :param vals: Map from boundary vertices of ``diag`` to 0 or 1.
    :raises ValueError: if a key is not a boundary vertex, or a value is not
        0 or 1.
    :return: The new diagram.
    """
    fixed = ZXDiagram(diag)
    # Relies on the copy listing its boundary in the same order as ``diag``.
    counterpart = dict(zip(diag.get_boundary(), fixed.get_boundary(), strict=False))
    for vert, bit in vals.items():
        if diag.get_zxtype(vert) not in _boundary_types:
            raise ValueError("Can only set states of boundary vertices")
        if bit not in [0, 1]:
            raise ValueError("Can only fix boundary states to |0> and |1>.")
        target = counterpart[vert]
        qtype = diag.get_qtype(vert)
        assert qtype is not None
        state = fixed.add_vertex(ZXType.XSpider, float(bit), qtype)
        wire = fixed.adj_wires(target)[0]
        neighbour = fixed.other_end(wire, target)
        port = dict(fixed.get_wire_ends(wire))[neighbour]
        fixed.add_wire(
            u=state,
            v=neighbour,
            v_port=port,
            type=fixed.get_wire_type(wire),
            qtype=qtype,
        )
        fixed.remove_vertex(target)
        factor = 0.5 if qtype == QuantumType.Quantum else sqrt(0.5)
        fixed.multiply_scalar(factor)
    return fixed
def fix_inputs_to_binary_state(diag: ZXDiagram, vals: list[int]) -> ZXDiagram:
    """
    Fixes all input vertices of a :py:class:`ZXDiagram` to computational basis states.

    ``vals[i]`` (0 or 1) is applied to the i-th vertex returned by
    ``diag.get_boundary(type=ZXType.Input)``.

    :raises ValueError: if ``vals`` does not have one entry per input, or via
        :py:func:`fix_boundaries_to_binary_states` for a value other than 0/1.
    """
    inputs = diag.get_boundary(type=ZXType.Input)
    n_inputs = len(inputs)
    n_vals = len(vals)
    if n_inputs != n_vals:
        raise ValueError(f"Gave {n_vals} values for {n_inputs} inputs of ZXDiagram")
    assignment = dict(zip(inputs, vals, strict=True))
    return fix_boundaries_to_binary_states(diag, assignment)
def fix_outputs_to_binary_state(diag: ZXDiagram, vals: list[int]) -> ZXDiagram:
    """
    Fixes all output vertices of a :py:class:`ZXDiagram` to computational basis
    states/post-selection.

    ``vals[i]`` (0 or 1) is applied to the i-th vertex returned by
    ``diag.get_boundary(type=ZXType.Output)``.

    :raises ValueError: if ``vals`` does not have one entry per output, or via
        :py:func:`fix_boundaries_to_binary_states` for a value other than 0/1.
    """
    outputs = diag.get_boundary(type=ZXType.Output)
    n_outputs = len(outputs)
    n_vals = len(vals)
    if n_outputs != n_vals:
        raise ValueError(
            f"Gave {n_vals} values for {n_outputs} outputs of ZXDiagram"
        )
    assignment = dict(zip(outputs, vals, strict=True))
    return fix_boundaries_to_binary_states(diag, assignment)