Coverage for /home/runner/work/tket/tket/pytket/pytket/utils/graph.py: 89%

167 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-11 07:40 +0000

1# Copyright Quantinuum 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15from collections import defaultdict 

16from itertools import combinations 

17from tempfile import NamedTemporaryFile 

18 

19import graphviz as gv # type: ignore 

20import networkx as nx # type: ignore 

21 

22from pytket.circuit import Circuit 

23 

24 

25class Graph: 

26 """ 

27 A class for visualising a circuit as a directed acyclic graph (DAG). 

28 

29 Note: in order to use graph-rendering methods, such as 

30 :py:meth:`Graph.save_DAG`, it is necessary to have the Graphviz tools installed 

31 and on your path. See the `Graphviz website <https://graphviz.org/download/>`_ 

32 for instructions on how to install them. 

33 """ 

34 

35 def __init__(self, c: Circuit): 

36 ( 

37 q_inputs, 

38 c_inputs, 

39 w_inputs, 

40 r_inputs, 

41 q_outputs, 

42 c_outputs, 

43 w_outputs, 

44 r_outputs, 

45 input_names, 

46 output_names, 

47 node_data, 

48 edge_data, 

49 ) = c._dag_data # noqa: SLF001 

50 self.q_inputs = q_inputs 

51 self.c_inputs = c_inputs 

52 self.w_inputs = w_inputs 

53 self.r_inputs = r_inputs 

54 self.q_outputs = q_outputs 

55 self.c_outputs = c_outputs 

56 self.w_outputs = w_outputs 

57 self.r_outputs = r_outputs 

58 self.input_names = input_names 

59 self.output_names = output_names 

60 self.node_data = node_data 

61 self.Gnx: nx.MultiDiGraph | None = None 

62 self.G: gv.Digraph | None = None 

63 self.Gqc: gv.Graph | None = None 

64 self.edge_data: dict[tuple[int, int], list[tuple[int, int, str]]] = defaultdict( 

65 list 

66 ) 

67 self.port_counts: dict = defaultdict(int) 

68 for src_node, tgt_node, src_port, tgt_port, edge_type in edge_data: 

69 self.edge_data[(src_node, tgt_node)].append((src_port, tgt_port, edge_type)) 

70 self.port_counts[(src_node, src_port)] += 1 

71 

72 def as_nx(self) -> nx.MultiDiGraph: 

73 """ 

74 Return a logical representation of the circuit as a DAG. 

75 

76 :returns: Representation of the DAG 

77 """ 

78 if self.Gnx is not None: 

79 return self.Gnx 

80 Gnx = nx.MultiDiGraph() 

81 for node, desc in self.node_data.items(): 

82 Gnx.add_node(node, desc=desc) 

83 for nodepair, portpairlist in self.edge_data.items(): 

84 src_node, tgt_node = nodepair 

85 for src_port, tgt_port, edge_type in portpairlist: 

86 Gnx.add_edge( 

87 src_node, 

88 tgt_node, 

89 src_port=src_port, 

90 tgt_port=tgt_port, 

91 edge_type=edge_type, 

92 ) 

93 

94 # Add node IDs to edges 

95 for edge in nx.topological_sort(nx.line_graph(Gnx)): 

96 src_node, tgt_node, _ = edge 

97 # List parent edges with matching port number 

98 src_port = Gnx.edges[edge]["src_port"] 

99 prev_edges = [ 

100 e 

101 for e in Gnx.in_edges(src_node, keys=True) 

102 if Gnx.edges[e]["tgt_port"] == src_port 

103 ] 

104 if not prev_edges: 

105 # The source must be an input node 

106 unit_id = src_node 

107 nx.set_edge_attributes(Gnx, {edge: {"unit_id": unit_id}}) 

108 else: 

109 # The parent must be unique 

110 assert len(prev_edges) == 1 

111 prev_edge = prev_edges[0] 

112 unit_id = Gnx.edges[prev_edge]["unit_id"] 

113 nx.set_edge_attributes(Gnx, {edge: {"unit_id": unit_id}}) 

114 

115 # Remove unnecessary port attributes to avoid clutter: 

116 for node in Gnx.nodes: 

117 if Gnx.in_degree(node) == 1: 

118 for edge in Gnx.in_edges(node, keys=True): 

119 nx.set_edge_attributes(Gnx, {edge: {"tgt_port": None}}) 

120 for edge in Gnx.out_edges(node, keys=True): 

121 nx.set_edge_attributes(Gnx, {edge: {"src_port": None}}) 

122 

123 self.Gnx = Gnx 

124 return Gnx 

125 

126 def get_DAG(self) -> gv.Digraph: # noqa: PLR0912, PLR0915 

127 """ 

128 Return a visual representation of the DAG as a graphviz object. 

129 

130 :returns: Representation of the DAG 

131 """ 

132 if self.G is not None: 132 ↛ 133line 132 didn't jump to line 133 because the condition on line 132 was never true

133 return self.G 

134 G = gv.Digraph( 

135 "Circuit", 

136 strict=True, 

137 ) 

138 G.attr(rankdir="LR", ranksep="0.3", nodesep="0.15", margin="0") 

139 q_color = "blue" 

140 c_color = "slategray" 

141 b_color = "gray" 

142 w_color = "green" 

143 r_color = "red" 

144 gate_color = "lightblue" 

145 boundary_cluster_attr = { 

146 "style": "rounded, filled", 

147 "color": "lightgrey", 

148 "margin": "5", 

149 } 

150 boundary_node_attr = {"fontname": "Courier", "fontsize": "8"} 

151 with G.subgraph(name="cluster_q_inputs") as c: 

152 c.attr(rank="source", **boundary_cluster_attr) 

153 c.node_attr.update(shape="point", color=q_color) 

154 for node in self.q_inputs: 

155 c.node( 

156 str((node, 0)), xlabel=self.input_names[node], **boundary_node_attr 

157 ) 

158 with G.subgraph(name="cluster_c_inputs") as c: 

159 c.attr(rank="source", **boundary_cluster_attr) 

160 c.node_attr.update(shape="point", color=c_color) 

161 for node in self.c_inputs: 

162 c.node( 

163 str((node, 0)), xlabel=self.input_names[node], **boundary_node_attr 

164 ) 

165 with G.subgraph(name="cluster_w_inputs") as c: 

166 c.attr(rank="source", **boundary_cluster_attr) 

167 c.node_attr.update(shape="point", color=w_color) 

168 for node in self.w_inputs: 168 ↛ 169line 168 didn't jump to line 169 because the loop on line 168 never started

169 c.node( 

170 str((node, 0)), xlabel=self.input_names[node], **boundary_node_attr 

171 ) 

172 with G.subgraph(name="cluster_r_inputs") as c: 

173 c.attr(rank="source", **boundary_cluster_attr) 

174 c.node_attr.update(shape="point", color=r_color) 

175 for node in self.r_inputs: 175 ↛ 176line 175 didn't jump to line 176 because the loop on line 175 never started

176 c.node( 

177 str((node, 0)), xlabel=self.input_names[node], **boundary_node_attr 

178 ) 

179 with G.subgraph(name="cluster_q_outputs") as c: 

180 c.attr(rank="sink", **boundary_cluster_attr) 

181 c.node_attr.update(shape="point", color=q_color) 

182 for node in self.q_outputs: 

183 c.node( 

184 str((node, 0)), xlabel=self.output_names[node], **boundary_node_attr 

185 ) 

186 with G.subgraph(name="cluster_c_outputs") as c: 

187 c.attr(rank="sink", **boundary_cluster_attr) 

188 c.node_attr.update(shape="point", color=c_color) 

189 for node in self.c_outputs: 

190 c.node( 

191 str((node, 0)), xlabel=self.output_names[node], **boundary_node_attr 

192 ) 

193 with G.subgraph(name="cluster_w_outputs") as c: 

194 c.attr(rank="sink", **boundary_cluster_attr) 

195 c.node_attr.update(shape="point", color=w_color) 

196 for node in self.w_outputs: 196 ↛ 197line 196 didn't jump to line 197 because the loop on line 196 never started

197 c.node( 

198 str((node, 0)), xlabel=self.output_names[node], **boundary_node_attr 

199 ) 

200 with G.subgraph(name="cluster_r_outputs") as c: 

201 c.attr(rank="sink", **boundary_cluster_attr) 

202 c.node_attr.update(shape="point", color=r_color) 

203 for node in self.r_outputs: 203 ↛ 204line 203 didn't jump to line 204 because the loop on line 203 never started

204 c.node( 

205 str((node, 0)), xlabel=self.output_names[node], **boundary_node_attr 

206 ) 

207 boundary_nodes = ( 

208 self.q_inputs 

209 | self.c_inputs 

210 | self.w_inputs 

211 | self.r_inputs 

212 | self.q_outputs 

213 | self.c_outputs 

214 | self.w_outputs 

215 | self.r_outputs 

216 ) 

217 Gnx = self.as_nx() 

218 node_cluster_attr = { 

219 "style": "rounded, filled", 

220 "color": gate_color, 

221 "fontname": "Times-Roman", 

222 "fontsize": "10", 

223 "margin": "5", 

224 "lheight": "100", 

225 } 

226 port_node_attr = { 

227 "shape": "point", 

228 "weight": "2", 

229 "fontname": "Helvetica", 

230 "fontsize": "8", 

231 } 

232 for node, ndata in Gnx.nodes.items(): 

233 if node not in boundary_nodes: 

234 with G.subgraph(name="cluster_" + str(node)) as c: 

235 c.attr(label=ndata["desc"], **node_cluster_attr) 

236 n_ports = Gnx.in_degree(node) 

237 if n_ports == 1: 

238 c.node(name=str((node, 0)), **port_node_attr) 

239 else: 

240 for i in range(n_ports): 

241 c.node(name=str((node, i)), xlabel=str(i), **port_node_attr) 

242 edge_colors = { 

243 "Quantum": q_color, 

244 "Boolean": b_color, 

245 "Classical": c_color, 

246 "WASM": w_color, 

247 "RNG": r_color, 

248 } 

249 edge_attr = { 

250 "weight": "2", 

251 "arrowhead": "vee", 

252 "arrowsize": "0.2", 

253 "headclip": "true", 

254 "tailclip": "true", 

255 } 

256 for edge, edata in Gnx.edges.items(): 

257 src_node, tgt_node, _ = edge 

258 src_port = edata["src_port"] or 0 

259 tgt_port = edata["tgt_port"] or 0 

260 edge_type = edata["edge_type"] 

261 src_nodename = str((src_node, src_port)) 

262 tgt_nodename = str((tgt_node, tgt_port)) 

263 G.edge( 

264 src_nodename, tgt_nodename, color=edge_colors[edge_type], **edge_attr 

265 ) 

266 self.G = G 

267 return G 

268 

269 def save_DAG(self, name: str, fmt: str = "pdf") -> None: 

270 """ 

271 Save an image of the DAG to a file. 

272 

273 The actual filename will be "<name>.<fmt>". A wide range of formats is 

274 supported. See https://graphviz.org/doc/info/output.html. 

275 

276 :param name: Prefix of file name 

277 :param fmt: File format, e.g. "pdf", "png", ... 

278 """ 

279 G = self.get_DAG() 

280 G.render(name, cleanup=True, format=fmt, quiet=True) 

281 

282 def view_DAG(self) -> str: 

283 """ 

284 View the DAG. 

285 

286 This method creates a temporary file, and returns its filename so that the 

287 caller may delete it afterwards. 

288 

289 :returns: filename of temporary created file 

290 """ 

291 G = self.get_DAG() 

292 filename = NamedTemporaryFile(delete=False).name # noqa: SIM115 

293 G.view(filename, quiet=True) 

294 return filename 

295 

296 def get_qubit_graph(self) -> gv.Graph: 

297 """ 

298 Return a visual representation of the qubit connectivity graph as a graphviz 

299 object. 

300 

301 :returns: Representation of the qubit connectivity graph of the circuit 

302 """ 

303 if self.Gqc is not None: 303 ↛ 304line 303 didn't jump to line 304 because the condition on line 303 was never true

304 return self.Gqc 

305 Gnx = self.as_nx() 

306 Gqcnx = nx.Graph() 

307 for node in Gnx.nodes(): 

308 qubits = [] 

309 for e in Gnx.in_edges(node, keys=True): 

310 unit_id = Gnx.edges[e]["unit_id"] 

311 if unit_id in self.q_inputs: 

312 qubits.append(unit_id) 

313 

314 Gqcnx.add_edges_from(combinations(qubits, 2)) 

315 G = gv.Graph( 

316 "Qubit connectivity", 

317 node_attr={ 

318 "shape": "circle", 

319 "color": "blue", 

320 "fontname": "Courier", 

321 "fontsize": "10", 

322 }, 

323 engine="neato", 

324 ) 

325 G.edges( 

326 (self.input_names[src], self.input_names[tgt]) for src, tgt in Gqcnx.edges() 

327 ) 

328 self.Gqc = G 

329 return G 

330 

331 def view_qubit_graph(self) -> str: 

332 """ 

333 View the qubit connectivity graph. 

334 

335 This method creates a temporary file, and returns its filename so that the 

336 caller may delete it afterwards. 

337 

338 :returns: filename of temporary created file 

339 """ 

340 G = self.get_qubit_graph() 

341 filename = NamedTemporaryFile(delete=False).name # noqa: SIM115 

342 G.view(filename, quiet=True) 

343 return filename 

344 

345 def save_qubit_graph(self, name: str, fmt: str = "pdf") -> None: 

346 """ 

347 Save an image of the qubit connectivity graph to a file. 

348 

349 The actual filename will be "<name>.<fmt>". A wide range of formats is 

350 supported. See https://graphviz.org/doc/info/output.html. 

351 

352 :param name: Prefix of file name 

353 :param fmt: File format, e.g. "pdf", "png", ... 

354 """ 

355 G = self.get_qubit_graph() 

356 G.render(name, cleanup=True, format=fmt, quiet=True)