Coverage for /home/runner/work/tket/tket/pytket/pytket/zx/tensor_eval.py: 97%

211 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-14 11:30 +0000

1# Copyright Quantinuum 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15"""Collection of methods to evaluate a ZXDiagram to a tensor. This uses the 

16numpy tensor features, in particular the einsum evaluation and optimisations.""" 

17import warnings 

18from math import cos, floor, pi, sin, sqrt 

19from typing import Any 

20 

21import numpy as np 

22import sympy 

23 

24from pytket.zx import ( 

25 CliffordGen, 

26 DirectedGen, 

27 PhasedGen, 

28 QuantumType, 

29 Rewrite, 

30 ZXBox, 

31 ZXDiagram, 

32 ZXGen, 

33 ZXType, 

34 ZXVert, 

35) 

36 

37try: 

38 import quimb.tensor as qtn # type: ignore 

39except ModuleNotFoundError: 

40 warnings.warn( 

41 'Missing package for tensor evaluation of ZX diagrams. Run "pip ' 

42 "install 'pytket[ZX]'\" to install the optional dependencies." 

43 ) 

44 

45 

46def _gen_to_tensor(gen: ZXGen, rank: int) -> np.ndarray: 

47 if isinstance(gen, PhasedGen): 

48 return _spider_to_tensor(gen, rank) 

49 if isinstance(gen, CliffordGen): 

50 return _clifford_to_tensor(gen, rank) 

51 if isinstance(gen, DirectedGen): 

52 return _dir_gen_to_tensor(gen) 

53 if isinstance(gen, ZXBox): 53 ↛ 55line 53 didn't jump to line 55 because the condition on line 53 was always true

54 return _tensor_from_basic_diagram(gen.diagram) 

55 raise ValueError(f"Cannot convert generator of type {gen.type} to a tensor") 

56 

57 

def _spider_to_tensor(gen: PhasedGen, rank: int) -> np.ndarray:
    """Build the tensor of a phased generator with ``rank`` legs.

    The tensor is first filled as a flat vector of ``2**rank`` entries and
    then reshaped to ``rank`` axes of dimension 2. For every type except the
    X spider only the first (all-zeros) and last (all-ones) entries are set
    explicitly.

    The parameter of an H-box is read as a complex number; for every other
    type it is read as a real number ``p`` and, where a phase is needed, the
    phase factor is ``exp(i * pi * (p mod 2))`` (negated exponent for XY).

    :param gen: phased generator to convert
    :param rank: number of legs of the resulting tensor
    :return: complex array of shape ``(2,) * rank``
    :raises ValueError: if the parameter cannot be converted to a number
        (symbolic expression) or the generator type is not handled
    """
    try:
        if gen.type == ZXType.Hbox:
            param_c = complex(gen.param)
        else:
            param = float(gen.param)
    except TypeError as e:
        # If parameter is symbolic, we cannot evaluate the tensor
        raise ValueError(
            f"Evaluation of ZXDiagram failed due to symbolic expression {gen.param}"
        ) from e
    size = pow(2, rank)
    if gen.type == ZXType.ZSpider:
        x = param / 2.0
        # Reduce the parameter modulo 2 before exponentiating
        modval = 2.0 * (x - floor(x))
        phase = np.exp(1j * modval * pi)
        t = np.zeros(size, dtype=complex)
        t[0] = 1.0
        # Accumulate instead of assigning: for rank 0 the first and last
        # entries coincide and the value must be 1 + phase, matching what
        # the X-spider branch below yields for rank 0. For rank >= 1 the last
        # entry starts at zero, so the result is unchanged.
        t[size - 1] += phase
    elif gen.type == ZXType.XSpider:
        x = param / 2.0
        modval = 2.0 * (x - floor(x))
        phase = np.exp(1j * modval * pi)
        t = np.full(size, 1.0, dtype=complex)
        constant = pow(sqrt(0.5), rank)
        for i in range(size):
            # Entry i is (1 +/- phase) / sqrt(2)**rank, the sign given by the
            # parity of the number of set bits in i
            parity = (i).bit_count()
            t[i] += phase if parity % 2 == 0 else -phase
            t[i] *= constant
    elif gen.type == ZXType.Hbox:
        # All entries are 1 except the last, which carries the parameter
        t = np.full(size, 1.0, dtype=complex)
        t[size - 1] = param_c
    elif gen.type == ZXType.XY:
        x = param / 2.0
        modval = 2.0 * (x - floor(x))
        # Note the negated exponent compared with the Z spider
        phase = np.exp(-1j * modval * pi)
        t = np.zeros(size, dtype=complex)
        t[0] = sqrt(0.5)
        t[size - 1] = sqrt(0.5) * phase
    elif gen.type == ZXType.XZ:
        x = param / 2.0
        # Here the halved parameter is reduced modulo 1 and used as an angle
        modval = x - floor(x)
        t = np.zeros(size, dtype=complex)
        t[0] = cos(modval * pi)
        t[size - 1] = sin(modval * pi)
    elif gen.type == ZXType.YZ:
        x = param / 2.0
        modval = x - floor(x)
        t = np.zeros(size, dtype=complex)
        t[0] = cos(modval * pi)
        t[size - 1] = -1j * sin(modval * pi)
    else:
        raise ValueError(
            f"Cannot convert phased generator of type {gen.type} to a tensor"
        )
    return t.reshape(tuple([2] * rank))

114 

115 

def _clifford_to_tensor(gen: CliffordGen, rank: int) -> np.ndarray:
    """Build the tensor of a Clifford (boolean-parameter) generator.

    Only the first and last entries of the flattened tensor are ever set;
    which values they take depends on the generator type and on the truth
    value of ``gen.param``.

    :param gen: Clifford generator of type PX, PY or PZ
    :param rank: number of legs of the resulting tensor
    :return: complex array of shape ``(2,) * rank``
    :raises ValueError: for any other generator type
    """
    n_entries = 2**rank
    last = n_entries - 1
    amp = sqrt(0.5)
    tensor = np.zeros(n_entries, dtype=complex)
    if gen.type == ZXType.PX:
        tensor[0] = amp
        tensor[last] = -amp if gen.param else amp
    elif gen.type == ZXType.PY:
        tensor[0] = amp
        tensor[last] = 1j * amp if gen.param else -1j * amp
    elif gen.type == ZXType.PZ:
        # A single unit entry: the last one when the parameter is set,
        # otherwise the first one
        tensor[last if gen.param else 0] = 1.0
    else:
        raise ValueError(
            f"Cannot convert Clifford generator of type {gen.type} to a tensor"
        )
    return tensor.reshape((2,) * rank)

135 

136 

def _dir_gen_to_tensor(gen: DirectedGen) -> np.ndarray:
    """Build the tensor of a directed generator.

    Only the Triangle type is handled; its tensor is the 2x2 matrix of ones
    with the ``[1, 0]`` entry set to zero.

    :param gen: directed generator to convert
    :return: complex 2x2 array
    :raises ValueError: for any type other than Triangle
    """
    if gen.type != ZXType.Triangle:
        raise ValueError(
            f"Cannot convert directed generator of type {gen.type} to a tensor"
        )
    return np.array([[1.0, 1.0], [0.0, 1.0]], dtype=complex)

145 

146 

# 2x2 complex identity, inserted as a tensor where two boundary vertices are
# joined directly by a wire.
_id_tensor = np.eye(2, dtype=complex)

# Vertex types treated as diagram boundaries during tensor evaluation.
_boundary_types = [ZXType.Input, ZXType.Output, ZXType.Open]

150 

151 

def _tensor_from_basic_diagram(diag: ZXDiagram) -> np.ndarray:
    """Evaluate a diagram to a tensor by contracting a tensor network.

    Every wire is given an integer index; each non-boundary vertex becomes a
    tensor whose legs are the indices of its adjacent wires. The network is
    simplified and contracted with quimb, keeping one open index per boundary
    vertex in the order returned by ``diag.get_boundary()``, and the result
    is multiplied by the diagram's scalar.

    :param diag: diagram to evaluate
    :return: the contracted tensor times the diagram's scalar
    :raises ValueError: if the diagram's scalar cannot be converted to a
        complex number (symbolic scalar)
    """
    try:
        scalar = complex(diag.scalar)
    except TypeError as e:
        raise ValueError(
            f"Cannot evaluate a diagram with a symbolic scalar. Given scalar: "
            f"{diag.scalar}"
        ) from e

    wires = diag.wires
    wire_index = {wire: pos for pos, wire in enumerate(wires)}
    # First index not already taken by a wire
    fresh_index = len(wires)
    tensors: list[Any] = []
    linked_wires = set()
    output_inds = []

    # Boundaries are processed first, on their own, so that the open indices
    # of the result follow the boundary order.
    for boundary in diag.get_boundary():
        wire = diag.adj_wires(boundary)[0]
        wire_pos = wire_index[wire]
        neighbour = diag.other_end(wire, boundary)
        if diag.get_zxtype(neighbour) not in _boundary_types or wire in linked_wires:
            output_inds.append(wire_pos)
            continue
        # The wire joins two boundaries directly and is seen for the first
        # time: put an identity tensor on it so that each of its two
        # boundaries gets its own open index.
        tensors.append(qtn.Tensor(data=_id_tensor, inds=[wire_pos, fresh_index]))
        output_inds.append(fresh_index)
        fresh_index += 1
        linked_wires.add(wire)

    for vert in diag.vertices:
        gen = diag.get_vertex_ZXGen(vert)
        if gen.type in _boundary_types:
            # Boundaries were dealt with in the loop above
            continue
        legs = []
        for wire in diag.adj_wires(vert):
            # A wire whose other end is this same vertex takes up two legs
            copies = 2 if diag.other_end(wire, vert) == vert else 1
            legs.extend([wire_index[wire]] * copies)
        tensors.append(qtn.Tensor(data=_gen_to_tensor(gen, len(legs)), inds=legs))

    net = qtn.TensorNetwork(tensors)
    net.full_simplify_(seq="ADCR")
    contracted = net.contract(output_inds=output_inds, optimize="greedy")
    # The contraction yields either a Tensor or a bare scalar
    result: np.ndarray = (
        contracted.data
        if isinstance(contracted, qtn.Tensor)
        else np.asarray(contracted)
    )
    return result * scalar

207 

208 

def tensor_from_quantum_diagram(diag: ZXDiagram) -> np.ndarray:
    """Evaluate a diagram made only of quantum components to a tensor.

    The diagram is copied, the copy's scalar is multiplied by
    ``1 / sqrt(diag.scalar)``, the ``basic_wires`` rewrite is applied and the
    copy is evaluated. ``diag`` itself is not modified.

    :param diag: diagram whose vertices and wires are all of quantum type
    :return: the tensor of the diagram
    :raises ValueError: if any vertex or wire is not of quantum type, or if
        evaluation of the rewritten copy fails
    """
    if any(diag.get_qtype(v) != QuantumType.Quantum for v in diag.vertices):
        # The message names this function; it used to refer to a
        # non-existent "evaluate_quantum_diagram".
        raise ValueError(
            "Non-quantum vertex found. tensor_from_quantum_diagram only "
            "supports diagrams consisting of only quantum components"
        )
    if any(diag.get_wire_qtype(w) != QuantumType.Quantum for w in diag.wires):
        raise ValueError(
            "Non-quantum wire found. tensor_from_quantum_diagram only "
            "supports diagrams consisting of only quantum components"
        )
    diag_copy = ZXDiagram(diag)
    # Leaves the copy with the square root of the original scalar
    # NOTE(review): presumably because a quantum diagram's scalar is kept in
    # doubled (squared) form - confirm against the ZXDiagram documentation.
    diag_copy.multiply_scalar(1 / sympy.sqrt(diag.scalar))
    Rewrite.basic_wires().apply(diag_copy)
    return _tensor_from_basic_diagram(diag_copy)

226 

227 

def tensor_from_mixed_diagram(diag: ZXDiagram) -> np.ndarray:
    """Evaluate a diagram through its doubled form.

    The diagram is converted with ``to_doubled_diagram``, the ``basic_wires``
    rewrite is applied to the doubled diagram, and the result is evaluated.

    :param diag: diagram to evaluate
    :return: the tensor of the doubled diagram
    """
    doubled = diag.to_doubled_diagram()
    basic_wires = Rewrite.basic_wires()
    basic_wires.apply(doubled)
    return _tensor_from_basic_diagram(doubled)

232 

233 

def _format_tensor_as_unitary(diag: ZXDiagram, tensor: np.ndarray) -> np.ndarray:
    """Rearrange a boundary-ordered tensor into a matrix.

    Axis ``k`` of ``tensor`` is taken to belong to the ``k``-th vertex of
    ``diag.get_boundary()``. Axes of Input boundaries are moved to the front,
    all remaining boundary axes follow, and the two groups are flattened and
    transposed.

    :param diag: diagram the tensor was computed from
    :param tensor: tensor with one axis per boundary vertex
    :return: matrix with ``2**(number of non-input boundaries)`` rows and
        ``2**(number of inputs)`` columns
    """
    input_axes: list[int] = []
    other_axes: list[int] = []
    for axis, vert in enumerate(diag.get_boundary()):
        if diag.get_zxtype(vert) == ZXType.Input:
            input_axes.append(axis)
        else:
            other_axes.append(axis)
    flat_shape = (2 ** len(input_axes), 2 ** len(other_axes))
    as_matrix = np.transpose(tensor, input_axes + other_axes).reshape(flat_shape)
    return as_matrix.T

247 

248 

def unitary_from_quantum_diagram(diag: ZXDiagram) -> np.ndarray:
    """Evaluate a purely quantum diagram and return it as a matrix.

    :param diag: diagram consisting only of quantum components
    :return: the diagram's tensor, rearranged into a matrix
    :raises ValueError: propagated from ``tensor_from_quantum_diagram``
    """
    return _format_tensor_as_unitary(diag, tensor_from_quantum_diagram(diag))

252 

253 

def unitary_from_classical_diagram(diag: ZXDiagram) -> np.ndarray:
    """Evaluate a diagram with classical boundaries and return it as a matrix.

    :param diag: diagram whose boundary vertices are all of classical type
    :return: the tensor of the doubled diagram, rearranged into a matrix
    :raises ValueError: if any boundary vertex is not classical
    """
    boundary_is_mixed = any(
        diag.get_qtype(b) != QuantumType.Classical for b in diag.get_boundary()
    )
    if boundary_is_mixed:
        raise ValueError(
            "Non-classical boundary vertex found. "
            "unitary_from_classical_diagram only supports diagrams with "
            "only classical boundaries"
        )
    return _format_tensor_as_unitary(diag, tensor_from_mixed_diagram(diag))

264 

265 

def density_matrix_from_cptp_diagram(diag: ZXDiagram) -> np.ndarray:
    """Evaluate a diagram with quantum boundaries to a square matrix.

    :param diag: diagram whose boundary vertices are all of quantum type
    :return: matrix of shape ``(2**n, 2**n)``, ``n`` being the number of
        boundary vertices of ``diag``
    :raises ValueError: if any boundary vertex is not quantum
    """
    boundary = diag.get_boundary()
    if any(diag.get_qtype(b) != QuantumType.Quantum for b in boundary):
        raise ValueError(
            "Non-quantum boundary vertex found. "
            "density_matrix_from_cptp_diagram only supports diagrams with "
            "only quantum boundaries"
        )
    tensor = tensor_from_mixed_diagram(diag)
    n_bounds = len(diag.get_boundary())
    dim = 2**n_bounds
    # diag.to_doubled_diagram() in tensor_from_mixed_diagram will alternate
    # original boundary vertices and their conjugates, so gather the even
    # axes first and the odd axes after them
    even_axes = list(range(0, 2 * n_bounds, 2))
    odd_axes = list(range(1, 2 * n_bounds, 2))
    as_matrix = np.transpose(tensor, even_axes + odd_axes).reshape((dim, dim))
    return as_matrix.T

282 

283 

def fix_boundaries_to_binary_states(
    diag: ZXDiagram, vals: dict[ZXVert, int]
) -> ZXDiagram:
    """Return a copy of ``diag`` with selected boundaries fixed to 0 or 1.

    For each listed boundary, the corresponding boundary of the copy is
    replaced by a new X spider with parameter ``float(val)``, wired to the
    boundary's neighbour on the same port with the same wire type, and the
    copy's scalar is multiplied by 0.5 (quantum boundary) or ``sqrt(0.5)``
    (otherwise). ``diag`` itself is left untouched.

    :param diag: diagram to copy
    :param vals: value (0 or 1) for each boundary vertex of ``diag`` to fix
    :return: the modified copy
    :raises ValueError: if a key is not a boundary vertex or a value is
        neither 0 nor 1
    """
    fixed = ZXDiagram(diag)
    # Pair each boundary of the original with its counterpart in the copy;
    # taken once, before the copy's boundaries start being removed
    counterpart = dict(zip(diag.get_boundary(), fixed.get_boundary()))
    for vert, bit in vals.items():
        if diag.get_zxtype(vert) not in _boundary_types:
            raise ValueError("Can only set states of boundary vertices")
        if bit not in (0, 1):
            raise ValueError("Can only fix boundary states to |0> and |1>.")
        old_boundary = counterpart[vert]
        qtype = diag.get_qtype(vert)
        assert qtype is not None
        state = fixed.add_vertex(ZXType.XSpider, float(bit), qtype)
        wire = fixed.adj_wires(old_boundary)[0]
        neighbour = fixed.other_end(wire, old_boundary)
        # Port at which the boundary's wire meets the neighbour
        neighbour_port = dict(fixed.get_wire_ends(wire))[neighbour]
        fixed.add_wire(
            u=state,
            v=neighbour,
            v_port=neighbour_port,
            type=fixed.get_wire_type(wire),
            qtype=qtype,
        )
        fixed.remove_vertex(old_boundary)
        factor = 0.5 if qtype == QuantumType.Quantum else sqrt(0.5)
        fixed.multiply_scalar(factor)
    return fixed

307 

308 

def fix_inputs_to_binary_state(diag: ZXDiagram, vals: list[int]) -> ZXDiagram:
    """Return a copy of ``diag`` with every input fixed to a given bit.

    Values are matched to ``diag.get_boundary(type=ZXType.Input)`` by
    position.

    :param diag: diagram to copy
    :param vals: one value (0 or 1) per input
    :return: the modified copy
    :raises ValueError: if the number of values differs from the number of
        inputs, or as raised by ``fix_boundaries_to_binary_states``
    """
    inputs = diag.get_boundary(type=ZXType.Input)
    if len(inputs) == len(vals):
        return fix_boundaries_to_binary_states(diag, dict(zip(inputs, vals)))
    raise ValueError(
        f"Gave {len(vals)} values for {len(inputs)} inputs of ZXDiagram"
    )

317 

318 

def fix_outputs_to_binary_state(diag: ZXDiagram, vals: list[int]) -> ZXDiagram:
    """Return a copy of ``diag`` with every output fixed to a given bit.

    Values are matched to ``diag.get_boundary(type=ZXType.Output)`` by
    position.

    :param diag: diagram to copy
    :param vals: one value (0 or 1) per output
    :return: the modified copy
    :raises ValueError: if the number of values differs from the number of
        outputs, or as raised by ``fix_boundaries_to_binary_states``
    """
    outputs = diag.get_boundary(type=ZXType.Output)
    if len(outputs) == len(vals):
        return fix_boundaries_to_binary_states(diag, dict(zip(outputs, vals)))
    raise ValueError(
        f"Gave {len(vals)} values for {len(outputs)} outputs of ZXDiagram"
    )