Coverage for /home/runner/work/tket/tket/pytket/pytket/utils/graph.py: 90%

154 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-09 15:08 +0000

1# Copyright Quantinuum 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15from collections import defaultdict 

16from itertools import combinations 

17from tempfile import NamedTemporaryFile 

18 

19import graphviz as gv # type: ignore 

20import networkx as nx # type: ignore 

21 

22from pytket.circuit import Circuit 

23 

24 

class Graph:
    """
    A class for visualising a circuit as a directed acyclic graph (DAG).

    Note: in order to use graph-rendering methods, such as
    :py:meth:`Graph.save_DAG`, it is necessary to have the Graphviz tools installed
    and on your path. See the `Graphviz website <https://graphviz.org/download/>`_
    for instructions on how to install them.
    """

    def __init__(self, c: Circuit):
        """
        Extract the DAG description of a circuit.

        :param c: Circuit to visualise
        :type c: Circuit
        """
        (
            q_inputs,
            c_inputs,
            w_inputs,
            q_outputs,
            c_outputs,
            w_outputs,
            input_names,
            output_names,
            node_data,
            edge_data,
        ) = c._dag_data  # noqa: SLF001
        self.q_inputs = q_inputs
        self.c_inputs = c_inputs
        self.w_inputs = w_inputs
        self.q_outputs = q_outputs
        self.c_outputs = c_outputs
        self.w_outputs = w_outputs
        self.input_names = input_names
        self.output_names = output_names
        self.node_data = node_data
        # Lazily-built caches for as_nx(), get_DAG() and get_qubit_graph().
        self.Gnx: nx.MultiDiGraph | None = None
        self.G: gv.Digraph | None = None
        self.Gqc: gv.Graph | None = None
        # Edges grouped by (source node, target node):
        # value is a list of (source port, target port, edge type).
        self.edge_data: dict[tuple[int, int], list[tuple[int, int, str]]] = defaultdict(
            list
        )
        # Number of edges leaving each (source node, source port) pair.
        self.port_counts: dict = defaultdict(int)
        for src_node, tgt_node, src_port, tgt_port, edge_type in edge_data:
            self.edge_data[(src_node, tgt_node)].append((src_port, tgt_port, edge_type))
            self.port_counts[(src_node, src_port)] += 1

    def as_nx(self) -> nx.MultiDiGraph:
        """
        Return a logical representation of the circuit as a DAG.

        Each node carries a ``desc`` attribute. Each edge carries ``src_port``,
        ``tgt_port``, ``edge_type`` and ``unit_id``. ``unit_id`` is the node id of
        the input node at which the edge's wire starts. Port attributes are set to
        ``None`` on edges around nodes that have exactly one incoming edge.

        The graph is built once and cached; later calls return the same object.

        :returns: Representation of the DAG
        :rtype: networkx.MultiDiGraph
        """
        if self.Gnx is not None:
            return self.Gnx
        Gnx = nx.MultiDiGraph()
        for node, desc in self.node_data.items():
            Gnx.add_node(node, desc=desc)
        for (src_node, tgt_node), portpairlist in self.edge_data.items():
            for src_port, tgt_port, edge_type in portpairlist:
                Gnx.add_edge(
                    src_node,
                    tgt_node,
                    src_port=src_port,
                    tgt_port=tgt_port,
                    edge_type=edge_type,
                )

        # Add unit IDs to edges. Visiting edges in topological order of the line
        # graph guarantees that an edge's parent already has its unit_id.
        for edge in nx.topological_sort(nx.line_graph(Gnx)):
            src_node, _, _ = edge
            # List parent edges entering the source node on the same port number
            src_port = Gnx.edges[edge]["src_port"]
            prev_edges = [
                e
                for e in Gnx.in_edges(src_node, keys=True)
                if Gnx.edges[e]["tgt_port"] == src_port
            ]
            if prev_edges:
                # The parent must be unique
                assert len(prev_edges) == 1
                unit_id = Gnx.edges[prev_edges[0]]["unit_id"]
            else:
                # The source must be an input node
                unit_id = src_node
            Gnx.edges[edge]["unit_id"] = unit_id

        # Remove unnecessary port attributes to avoid clutter. get_DAG draws a
        # node with a single incoming edge as one port and maps None to port 0.
        for node in Gnx.nodes:
            if Gnx.in_degree(node) == 1:
                for edge in Gnx.in_edges(node, keys=True):
                    Gnx.edges[edge]["tgt_port"] = None
                for edge in Gnx.out_edges(node, keys=True):
                    Gnx.edges[edge]["src_port"] = None

        self.Gnx = Gnx
        return Gnx

    def get_DAG(self) -> gv.Digraph:
        """
        Return a visual representation of the DAG as a graphviz object.

        The object is built once and cached; later calls return the same object.

        :returns: Representation of the DAG
        :rtype: graphviz.DiGraph
        """
        if self.G is not None:
            return self.G
        G = gv.Digraph(
            "Circuit",
            strict=True,
        )
        G.attr(rankdir="LR", ranksep="0.3", nodesep="0.15", margin="0")
        q_color = "blue"
        c_color = "slategray"
        b_color = "gray"
        w_color = "green"
        gate_color = "lightblue"
        self._add_boundary_clusters(G, q_color, c_color, w_color)
        Gnx = self.as_nx()
        self._add_gate_clusters(G, Gnx, gate_color)
        edge_colors = {
            "Quantum": q_color,
            "Boolean": b_color,
            "Classical": c_color,
            "WASM": w_color,
        }
        self._add_edges(G, Gnx, edge_colors)
        self.G = G
        return G

    def _add_boundary_clusters(
        self, G: gv.Digraph, q_color: str, c_color: str, w_color: str
    ) -> None:
        """Draw input/output boundary nodes: one cluster per (side, wire type)."""
        boundary_cluster_attr = {
            "style": "rounded, filled",
            "color": "lightgrey",
            "margin": "5",
        }
        boundary_node_attr = {"fontname": "Courier", "fontsize": "8"}
        # (cluster name, rank, node colour, nodes, node-name lookup), in the
        # order the clusters are emitted into the DOT source.
        clusters = (
            ("cluster_q_inputs", "source", q_color, self.q_inputs, self.input_names),
            ("cluster_c_inputs", "source", c_color, self.c_inputs, self.input_names),
            ("cluster_w_inputs", "source", w_color, self.w_inputs, self.input_names),
            ("cluster_q_outputs", "sink", q_color, self.q_outputs, self.output_names),
            ("cluster_c_outputs", "sink", c_color, self.c_outputs, self.output_names),
            ("cluster_w_outputs", "sink", w_color, self.w_outputs, self.output_names),
        )
        for cluster_name, rank, color, nodes, names in clusters:
            with G.subgraph(name=cluster_name) as c:
                c.attr(rank=rank, **boundary_cluster_attr)
                c.node_attr.update(shape="point", color=color)
                for node in nodes:
                    c.node(str((node, 0)), xlabel=names[node], **boundary_node_attr)

    def _add_gate_clusters(
        self, G: gv.Digraph, Gnx: nx.MultiDiGraph, gate_color: str
    ) -> None:
        """Draw each non-boundary node as a labelled cluster of port points."""
        boundary_nodes = (
            self.q_inputs
            | self.c_inputs
            | self.w_inputs
            | self.q_outputs
            | self.c_outputs
            | self.w_outputs
        )
        node_cluster_attr = {
            "style": "rounded, filled",
            "color": gate_color,
            "fontname": "Times-Roman",
            "fontsize": "10",
            "margin": "5",
            "lheight": "100",
        }
        port_node_attr = {
            "shape": "point",
            "weight": "2",
            "fontname": "Helvetica",
            "fontsize": "8",
        }
        for node, ndata in Gnx.nodes.items():
            if node in boundary_nodes:
                continue
            with G.subgraph(name="cluster_" + str(node)) as c:
                c.attr(label=ndata["desc"], **node_cluster_attr)
                n_ports = Gnx.in_degree(node)
                if n_ports == 1:
                    # Single port: no port-number label needed.
                    c.node(name=str((node, 0)), **port_node_attr)
                else:
                    for i in range(n_ports):
                        c.node(name=str((node, i)), xlabel=str(i), **port_node_attr)

    @staticmethod
    def _add_edges(
        G: gv.Digraph, Gnx: nx.MultiDiGraph, edge_colors: dict[str, str]
    ) -> None:
        """Draw every DAG edge between port points, coloured by edge type."""
        edge_attr = {
            "weight": "2",
            "arrowhead": "vee",
            "arrowsize": "0.2",
            "headclip": "true",
            "tailclip": "true",
        }
        for (src_node, tgt_node, _), edata in Gnx.edges.items():
            # as_nx() sets ports to None around single-port nodes; draw them as 0.
            src_port = edata["src_port"] or 0
            tgt_port = edata["tgt_port"] or 0
            G.edge(
                str((src_node, src_port)),
                str((tgt_node, tgt_port)),
                color=edge_colors[edata["edge_type"]],
                **edge_attr,
            )

    @staticmethod
    def _temp_filename() -> str:
        """Create an empty, persistent temporary file and return its name."""
        # delete=False keeps the file after the handle is closed; the `with`
        # closes the handle now instead of relying on garbage collection.
        with NamedTemporaryFile(delete=False) as f:
            return f.name

    def save_DAG(self, name: str, fmt: str = "pdf") -> None:
        """
        Save an image of the DAG to a file.

        The actual filename will be "<name>.<fmt>". A wide range of formats is
        supported. See https://graphviz.org/doc/info/output.html.

        :param name: Prefix of file name
        :type name: str
        :param fmt: File format, e.g. "pdf", "png", ...
        :type fmt: str
        """
        G = self.get_DAG()
        G.render(name, cleanup=True, format=fmt, quiet=True)

    def view_DAG(self) -> str:
        """
        View the DAG.

        This method creates a temporary file, and returns its filename so that the
        caller may delete it afterwards.

        NOTE: Graphviz's ``view`` may also write a rendered file derived from this
        name. Deleting the returned file alone may not remove it.

        :returns: filename of temporary created file
        """
        G = self.get_DAG()
        filename = self._temp_filename()
        G.view(filename, quiet=True)
        return filename

    def get_qubit_graph(self) -> gv.Graph:
        """
        Return a visual representation of the qubit connectivity graph as a graphviz
        object.

        Two qubits are connected when some DAG node has incoming edges from both of
        their wires. Qubits that never share a node are not drawn. The object is
        built once and cached.

        :returns: Representation of the qubit connectivity graph of the circuit
        :rtype: graphviz.Graph
        """
        if self.Gqc is not None:
            return self.Gqc
        Gnx = self.as_nx()
        Gqcnx = nx.Graph()
        for node in Gnx.nodes():
            # Qubit wires (identified by their input node) entering this node.
            qubits = [
                unit_id
                for unit_id in (
                    Gnx.edges[e]["unit_id"] for e in Gnx.in_edges(node, keys=True)
                )
                if unit_id in self.q_inputs
            ]
            Gqcnx.add_edges_from(combinations(qubits, 2))
        G = gv.Graph(
            "Qubit connectivity",
            node_attr={
                "shape": "circle",
                "color": "blue",
                "fontname": "Courier",
                "fontsize": "10",
            },
            engine="neato",
        )
        G.edges(
            (self.input_names[src], self.input_names[tgt]) for src, tgt in Gqcnx.edges()
        )
        self.Gqc = G
        return G

    def view_qubit_graph(self) -> str:
        """
        View the qubit connectivity graph.

        This method creates a temporary file, and returns its filename so that the
        caller may delete it afterwards.

        NOTE: Graphviz's ``view`` may also write a rendered file derived from this
        name. Deleting the returned file alone may not remove it.

        :returns: filename of temporary created file
        """
        G = self.get_qubit_graph()
        filename = self._temp_filename()
        G.view(filename, quiet=True)
        return filename

    def save_qubit_graph(self, name: str, fmt: str = "pdf") -> None:
        """
        Save an image of the qubit connectivity graph to a file.

        The actual filename will be "<name>.<fmt>". A wide range of formats is
        supported. See https://graphviz.org/doc/info/output.html.

        :param name: Prefix of file name
        :type name: str
        :param fmt: File format, e.g. "pdf", "png", ...
        :type fmt: str
        """
        G = self.get_qubit_graph()
        G.render(name, cleanup=True, format=fmt, quiet=True)