Coverage for /home/runner/work/tket/tket/pytket/pytket/zx/tensor_eval.py: 97%

211 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-09 15:08 +0000

1# Copyright Quantinuum 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

"""Collection of methods to evaluate a ZXDiagram to a tensor. Each generator is
converted to a numpy array, and the diagram is assembled into a quimb tensor
network which is then simplified and contracted to produce the result."""

17 

18import warnings 

19from math import cos, floor, pi, sin, sqrt 

20from typing import Any 

21 

22import numpy as np 

23import sympy 

24 

25from pytket.zx import ( 

26 CliffordGen, 

27 DirectedGen, 

28 PhasedGen, 

29 QuantumType, 

30 Rewrite, 

31 ZXBox, 

32 ZXDiagram, 

33 ZXGen, 

34 ZXType, 

35 ZXVert, 

36) 

37 

try:
    # quimb is an optional dependency (the pytket[ZX] extra); it supplies the
    # tensor network objects used to contract diagrams in this module.
    import quimb.tensor as qtn  # type: ignore
except ModuleNotFoundError:
    # The module still imports without quimb; only code paths that reach
    # ``qtn`` will fail later, since the name is left unbound.
    warnings.warn(
        'Missing package for tensor evaluation of ZX diagrams. Run "pip '
        "install 'pytket[ZX]'\" to install the optional dependencies.",
        stacklevel=2,
    )

46 

47 

def _gen_to_tensor(gen: ZXGen, rank: int) -> np.ndarray:
    """
    Dispatch a generator to the tensor evaluator for its generator class.

    :param gen: The generator to evaluate.
    :param rank: Number of tensor indices the generator should have.
    :return: The generator's tensor as a numpy array.
    :raises ValueError: If the generator class is not supported.
    """
    # Classes are tested in a fixed order; the first match wins.
    if isinstance(gen, PhasedGen):
        tensor = _spider_to_tensor(gen, rank)
    elif isinstance(gen, CliffordGen):
        tensor = _clifford_to_tensor(gen, rank)
    elif isinstance(gen, DirectedGen):
        # Directed generators are evaluated without reference to ``rank``.
        tensor = _dir_gen_to_tensor(gen)
    elif isinstance(gen, ZXBox):
        # A box is evaluated recursively from its inner diagram.
        # NOTE(review): the inner diagram is not passed through
        # Rewrite.basic_wires() here; confirm boxes only hold basic wires.
        tensor = _tensor_from_basic_diagram(gen.diagram)
    else:
        raise ValueError(f"Cannot convert generator of type {gen.type} to a tensor")
    return tensor

58 

59 

def _spider_to_tensor(gen: PhasedGen, rank: int) -> np.ndarray:
    """
    Evaluate a phased generator as a tensor with ``rank`` indices of
    dimension 2.

    Real phases are interpreted in half-turns (multiples of pi) and reduced
    modulo 2 before use. The tensor is built flat and reshaped in C order, so
    flat entry 0 is the all-zeros basis state and entry ``2**rank - 1`` is the
    all-ones basis state. For ``rank == 0`` those coincide, and the two
    contributions are summed rather than one overwriting the other. For
    example, a legless Z spider evaluates to ``1 + e^{i*pi*param}``, matching
    the summation already done for X spiders.

    :param gen: The phased generator (Z/X spider, H-box, or XY/XZ/YZ vertex).
    :param rank: Number of indices of the resulting tensor.
    :return: A complex numpy array of shape ``(2,) * rank``.
    :raises ValueError: If the parameter is symbolic, or the generator type is
        not supported.
    """
    try:
        if gen.type == ZXType.Hbox:
            param_c = complex(gen.param)
        else:
            param = float(gen.param)
    except TypeError as e:
        # If parameter is symbolic, we cannot evaluate the tensor
        raise ValueError(
            f"Evaluation of ZXDiagram failed due to symbolic expression {gen.param}"
        ) from e
    size = pow(2, rank)
    # Flat index of the all-ones basis state; equals 0 when rank == 0.
    last = size - 1
    if gen.type == ZXType.ZSpider:
        x = param / 2.0
        modval = 2.0 * (x - floor(x))
        phase = np.exp(1j * modval * pi)
        t = np.zeros(size, dtype=complex)
        t[0] = 1.0
        # Accumulate (not assign) so a rank-0 spider gives 1 + phase.
        t[last] += phase
    elif gen.type == ZXType.XSpider:
        x = param / 2.0
        modval = 2.0 * (x - floor(x))
        phase = np.exp(1j * modval * pi)
        t = np.full(size, 1.0, dtype=complex)
        constant = pow(sqrt(0.5), rank)
        for i in range(size):
            # |+>^n contributes +1 everywhere; |->^n contributes (-1)^parity.
            parity = (i).bit_count()
            t[i] += phase if parity % 2 == 0 else -phase
            t[i] *= constant
    elif gen.type == ZXType.Hbox:
        t = np.full(size, 1.0, dtype=complex)
        t[last] = param_c
    elif gen.type == ZXType.XY:
        x = param / 2.0
        modval = 2.0 * (x - floor(x))
        phase = np.exp(-1j * modval * pi)
        t = np.zeros(size, dtype=complex)
        t[0] = sqrt(0.5)
        t[last] += sqrt(0.5) * phase
    elif gen.type == ZXType.XZ:
        x = param / 2.0
        modval = x - floor(x)
        t = np.zeros(size, dtype=complex)
        t[0] = cos(modval * pi)
        t[last] += sin(modval * pi)
    elif gen.type == ZXType.YZ:
        x = param / 2.0
        modval = x - floor(x)
        t = np.zeros(size, dtype=complex)
        t[0] = cos(modval * pi)
        t[last] += -1j * sin(modval * pi)
    else:
        raise ValueError(
            f"Cannot convert phased generator of type {gen.type} to a tensor"
        )
    return t.reshape(tuple([2] * rank))

116 

117 

def _clifford_to_tensor(gen: CliffordGen, rank: int) -> np.ndarray:
    """
    Evaluate a Clifford (Pauli-measurement) generator as a tensor with
    ``rank`` indices of dimension 2.

    The tensor is built flat and reshaped in C order, so flat entry 0 is the
    all-zeros basis state and entry ``2**rank - 1`` the all-ones state. For
    ``rank == 0`` these coincide, and the PX/PY contributions are summed
    instead of the second overwriting the first.

    :param gen: The Clifford generator (PX, PY or PZ); ``gen.param`` selects
        between the two outcomes.
    :param rank: Number of indices of the resulting tensor.
    :return: A complex numpy array of shape ``(2,) * rank``.
    :raises ValueError: If the generator type is not supported.
    """
    size = pow(2, rank)
    # Flat index of the all-ones basis state; equals 0 when rank == 0.
    last = size - 1
    t = np.zeros(size, dtype=complex)
    if gen.type == ZXType.PX:
        t[0] = sqrt(0.5)
        t[last] += -sqrt(0.5) if gen.param else sqrt(0.5)
    elif gen.type == ZXType.PY:
        t[0] = sqrt(0.5)
        t[last] += 1j * sqrt(0.5) if gen.param else -1j * sqrt(0.5)
    elif gen.type == ZXType.PZ:
        # Only one basis state is populated, so there is nothing to sum.
        t[last if gen.param else 0] = 1.0
    else:
        raise ValueError(
            f"Cannot convert Clifford generator of type {gen.type} to a tensor"
        )
    return t.reshape(tuple([2] * rank))

137 

138 

def _dir_gen_to_tensor(gen: DirectedGen) -> np.ndarray:
    """
    Evaluate a directed generator as a tensor.

    Only the triangle is supported; it becomes the 2x2 complex array
    ``[[1, 1], [0, 1]]``.

    :raises ValueError: For any other directed generator type.
    """
    if gen.type != ZXType.Triangle:
        raise ValueError(
            f"Cannot convert directed generator of type {gen.type} to a tensor"
        )
    return np.array([[1, 1], [0, 1]], dtype=complex)

147 

148 

# 2x2 complex identity, inserted between two directly connected boundaries.
_id_tensor = np.eye(2, dtype=complex)

150 

# Vertex types that form a diagram's boundary. They carry no tensor of their
# own; their wires become the open indices of the evaluated tensor.
_boundary_types = [
    ZXType.Input,
    ZXType.Output,
    ZXType.Open,
]

152 

153 

def _tensor_from_basic_diagram(diag: ZXDiagram) -> np.ndarray:
    """
    Contract a diagram (expected to contain only basic wires) into a single
    numpy tensor.

    Every wire gets an integer index. Boundary vertices contribute no tensor;
    their wire index becomes an open index of the result, in the order of
    :py:meth:`ZXDiagram.get_boundary`. When two boundaries are joined directly
    by a wire, an identity tensor is inserted so each boundary gets its own
    index. The diagram's scalar is multiplied into the result.

    :raises ValueError: If the diagram's scalar, or any generator parameter,
        is symbolic.
    """
    try:
        scalar = complex(diag.scalar)
    except TypeError as e:
        raise ValueError(
            f"Cannot evaluate a diagram with a symbolic scalar. Given scalar: "
            f"{diag.scalar}"
        ) from e

    all_wires = diag.wires
    wire_index = {wire: pos for pos, wire in enumerate(all_wires)}
    fresh = len(all_wires)  # next unused index, for identity-tensor outputs
    tensors: list[Any] = []
    bridged_wires = set()  # boundary-to-boundary wires already given an identity
    output_inds = []

    # Boundaries are processed first so output_inds follows get_boundary().
    for bound in diag.get_boundary():
        wire = diag.adj_wires(bound)[0]
        idx = wire_index[wire]
        far_end = diag.other_end(wire, bound)
        if diag.get_zxtype(far_end) not in _boundary_types or wire in bridged_wires:
            output_inds.append(idx)
            continue
        # Boundary wired straight to another boundary: route through an
        # identity so the two boundaries end up with distinct indices.
        tensors.append(qtn.Tensor(data=_id_tensor, inds=[idx, fresh]))
        output_inds.append(fresh)
        fresh += 1
        bridged_wires.add(wire)

    for vert in diag.vertices:
        gen = diag.get_vertex_ZXGen(vert)
        if gen.type in _boundary_types:
            continue  # already represented by output indices
        vert_inds = []
        for wire in diag.adj_wires(vert):
            vert_inds.append(wire_index[wire])
            if diag.other_end(wire, vert) == vert:
                # Self-loop: repeat the index so both wire ends are present
                # (assumes adj_wires lists a self-loop only once - confirm).
                vert_inds.append(wire_index[wire])
        data = _gen_to_tensor(gen, len(vert_inds))
        tensors.append(qtn.Tensor(data=data, inds=vert_inds))

    net = qtn.TensorNetwork(tensors)
    net.full_simplify_(seq="ADCR")
    contracted = net.contract(output_inds=output_inds, optimize="greedy")
    # A non-Tensor contraction result is treated as a scalar.
    result: np.ndarray = (
        contracted.data
        if isinstance(contracted, qtn.Tensor)
        else np.asarray(contracted)
    )
    return result * scalar

209 

210 

def tensor_from_quantum_diagram(diag: ZXDiagram) -> np.ndarray:
    """
    Evaluate a purely quantum :py:class:`ZXDiagram` as a tensor.

    The indices of the returned tensor follow the order of the boundary
    vertices given by :py:meth:`ZXDiagram.get_boundary`.

    Raises an exception if any vertex or wire of the diagram is not quantum,
    or if the diagram contains symbolic parameters.
    """
    quantum = QuantumType.Quantum
    if any(diag.get_qtype(v) != quantum for v in diag.vertices):
        raise ValueError(
            "Non-quantum vertex found. tensor_from_quantum_diagram only "
            "supports diagrams consisting of only quantum components"
        )
    if any(diag.get_wire_qtype(w) != quantum for w in diag.wires):
        raise ValueError(
            "Non-quantum wire found. tensor_from_quantum_diagram only "
            "supports diagrams consisting of only quantum components"
        )
    # Work on a copy so the caller's diagram is left untouched.
    working = ZXDiagram(diag)
    # Multiplying the stored scalar s by 1/sqrt(s) leaves sqrt(s) (assuming
    # multiply_scalar multiplies). NOTE(review): presumably s is kept in the
    # doubled picture - confirm.
    working.multiply_scalar(1 / sympy.sqrt(diag.scalar))
    Rewrite.basic_wires().apply(working)
    return _tensor_from_basic_diagram(working)

236 

237 

def tensor_from_mixed_diagram(diag: ZXDiagram) -> np.ndarray:
    """
    Evaluate an arbitrary :py:class:`ZXDiagram` as a tensor in the doubled
    picture: every quantum generator becomes a pair of conjugate generators,
    while a classical generator is kept as it is.

    Tensor indices follow :py:meth:`ZXDiagram.get_boundary`, with each quantum
    boundary split into two consecutive indices. For instance, a boundary of
    ``[qb1, cb1, qb2]`` yields indices ``[qb1, qb1_conj, cb1, qb2, qb2_conj]``.

    Raises an exception if the diagram contains symbolic parameters.
    """
    doubled = diag.to_doubled_diagram()
    # Reduce to basic wires, as required by the tensor builder.
    Rewrite.basic_wires().apply(doubled)
    return _tensor_from_basic_diagram(doubled)

254 

255 

def _format_tensor_as_unitary(diag: ZXDiagram, tensor: np.ndarray) -> np.ndarray:
    """
    Reshape a tensor indexed by ``diag``'s boundary into a matrix.

    Input boundaries become the columns; every other boundary (outputs and
    open boundaries) becomes the rows. Within each group, the relative order
    from :py:meth:`ZXDiagram.get_boundary` is kept, and the first index is
    the most significant (C-order reshape).
    """
    boundary = diag.get_boundary()
    is_input = [diag.get_zxtype(b) == ZXType.Input for b in boundary]
    in_ind = [pos for pos, flag in enumerate(is_input) if flag]
    out_ind = [pos for pos, flag in enumerate(is_input) if not flag]
    permuted = np.transpose(tensor, in_ind + out_ind)
    # Shape is (inputs, outputs); transpose to get outputs x inputs.
    matrix = permuted.reshape(2 ** len(in_ind), 2 ** len(out_ind))
    return matrix.T

269 

270 

def unitary_from_quantum_diagram(diag: ZXDiagram) -> np.ndarray:
    """
    Evaluates a purely quantum :py:class:`ZXDiagram` as a matrix describing the
    linear map from inputs to outputs. Qubits are indexed according to ILO-BE
    convention based on relative position amongst inputs/outputs in
    :py:meth:`ZXDiagram.get_boundary`.

    Throws an exception if the diagram contains any non-quantum vertex or wire,
    or if it contains any symbolic parameters.
    """
    tensor = tensor_from_quantum_diagram(diag)
    # Group boundary indices into inputs (columns) and outputs (rows).
    return _format_tensor_as_unitary(diag, tensor)

283 

284 

def unitary_from_classical_diagram(diag: ZXDiagram) -> np.ndarray:
    """
    Evaluate a :py:class:`ZXDiagram` with only classical boundaries as a
    matrix mapping inputs to outputs. Bits are ordered by the ILO-BE
    convention, following their relative position among inputs/outputs in
    :py:meth:`ZXDiagram.get_boundary`. Internally, each quantum generator is
    evaluated as a pair of conjugate generators.

    Raises an exception if any boundary is not classical, or if the diagram
    contains symbolic parameters.
    """
    if any(
        diag.get_qtype(b) != QuantumType.Classical for b in diag.get_boundary()
    ):
        raise ValueError(
            "Non-classical boundary vertex found. "
            "unitary_from_classical_diagram only supports diagrams with "
            "only classical boundaries"
        )
    # Classical boundaries are not doubled, so each keeps a single index.
    return _format_tensor_as_unitary(diag, tensor_from_mixed_diagram(diag))

305 

306 

def density_matrix_from_cptp_diagram(diag: ZXDiagram) -> np.ndarray:
    """
    Evaluate a :py:class:`ZXDiagram` whose boundaries are all quantum (its
    internal generators may mix quantum and classical) as a density matrix.
    Inputs are treated just like outputs, so the result is the diagram's
    Choi state. Qubits are ordered by the ILO-BE convention following
    :py:meth:`ZXDiagram.get_boundary`.

    Raises an exception if any boundary is not quantum, or if the diagram
    contains symbolic parameters.
    """
    boundary = diag.get_boundary()
    if any(diag.get_qtype(b) != QuantumType.Quantum for b in boundary):
        raise ValueError(
            "Non-quantum boundary vertex found. "
            "density_matrix_from_cptp_diagram only supports diagrams with "
            "only quantum boundaries"
        )
    tensor = tensor_from_mixed_diagram(diag)
    n_bounds = len(boundary)
    dim = 2**n_bounds
    # diag.to_doubled_diagram() in tensor_from_mixed_diagram will alternate
    # original boundary vertices and their conjugates, so even positions are
    # the originals and odd positions the conjugates.
    perm = list(range(0, 2 * n_bounds, 2)) + list(range(1, 2 * n_bounds, 2))
    return np.transpose(tensor, perm).reshape(dim, dim).T

333 

334 

def fix_boundaries_to_binary_states(
    diag: ZXDiagram, vals: dict[ZXVert, int]
) -> ZXDiagram:
    """
    Fix some or all boundary vertices of a :py:class:`ZXDiagram` to
    computational basis states (or post-select on them).

    :param diag: Diagram to fix; it is copied, not modified.
    :param vals: Maps boundary vertices of ``diag`` to the bit (0 or 1) they
        are fixed to.
    :return: A new diagram in which each fixed boundary is replaced by a
        single-legged X spider with phase 0 or 1 (half-turns), scaled so that
        it represents a normalised basis state.
    :raises ValueError: If a key is not a boundary vertex, or a value is not
        0 or 1.
    """
    result = ZXDiagram(diag)
    # Matches boundaries positionally (assumes the copy keeps boundary order).
    copy_of = dict(zip(diag.get_boundary(), result.get_boundary(), strict=False))
    for vert, bit in vals.items():
        if diag.get_zxtype(vert) not in _boundary_types:
            raise ValueError("Can only set states of boundary vertices")
        if bit not in [0, 1]:
            raise ValueError("Can only fix boundary states to |0> and |1>.")
        target = copy_of[vert]
        qtype = diag.get_qtype(vert)
        assert qtype is not None
        state = result.add_vertex(ZXType.XSpider, float(bit), qtype)
        # Reattach the neighbour to the new state vertex on the same port and
        # with the same wire type, then drop the old boundary.
        wire = result.adj_wires(target)[0]
        neighbour = result.other_end(wire, target)
        port = dict(result.get_wire_ends(wire))[neighbour]
        result.add_wire(
            u=state,
            v=neighbour,
            v_port=port,
            type=result.get_wire_type(wire),
            qtype=qtype,
        )
        result.remove_vertex(target)
        # A 1-legged X spider evaluates to sqrt(2)|bit>; normalise, squaring
        # the factor for quantum (doubled) boundaries.
        result.multiply_scalar(0.5 if qtype == QuantumType.Quantum else sqrt(0.5))
    return result

362 

363 

def fix_inputs_to_binary_state(diag: ZXDiagram, vals: list[int]) -> ZXDiagram:
    """
    Fix every input vertex of a :py:class:`ZXDiagram` to a computational
    basis state. ``vals[i]`` is the bit for the i-th input returned by
    ``diag.get_boundary(type=ZXType.Input)``.

    :raises ValueError: If the number of values differs from the number of
        inputs.
    """
    inputs = diag.get_boundary(type=ZXType.Input)
    n_given, n_inputs = len(vals), len(inputs)
    if n_given != n_inputs:
        raise ValueError(f"Gave {n_given} values for {n_inputs} inputs of ZXDiagram")
    return fix_boundaries_to_binary_states(diag, dict(zip(inputs, vals, strict=False)))

375 

376 

def fix_outputs_to_binary_state(diag: ZXDiagram, vals: list[int]) -> ZXDiagram:
    """
    Fix every output vertex of a :py:class:`ZXDiagram` to a computational
    basis state, i.e. post-select on it. ``vals[i]`` is the bit for the i-th
    output returned by ``diag.get_boundary(type=ZXType.Output)``.

    :raises ValueError: If the number of values differs from the number of
        outputs.
    """
    outputs = diag.get_boundary(type=ZXType.Output)
    n_given, n_outputs = len(vals), len(outputs)
    if n_given != n_outputs:
        raise ValueError(
            f"Gave {n_given} values for {n_outputs} outputs of ZXDiagram"
        )
    return fix_boundaries_to_binary_states(
        diag, dict(zip(outputs, vals, strict=False))
    )