Coverage for /home/runner/work/tket/tket/pytket/pytket/utils/graph.py: 90%

154 statements  

« prev     ^ index     » next       coverage.py v7.8.2, created at 2025-06-02 12:44 +0000

1# Copyright Quantinuum 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15from collections import defaultdict 

16from itertools import combinations 

17from tempfile import NamedTemporaryFile 

18 

19import graphviz as gv # type: ignore 

20import networkx as nx # type: ignore 

21 

22from pytket.circuit import Circuit 

23 

24 

class Graph:
    """
    A class for visualising a circuit as a directed acyclic graph (DAG).

    Note: in order to use graph-rendering methods, such as
    :py:meth:`Graph.save_DAG`, it is necessary to have the Graphviz tools installed
    and on your path. See the `Graphviz website <https://graphviz.org/download/>`_
    for instructions on how to install them.
    """

    # Colours used when drawing the DAG: one per kind of edge, plus the fill
    # colour of the boxes drawn around operations.
    _Q_COLOR = "blue"
    _C_COLOR = "slategray"
    _B_COLOR = "gray"
    _W_COLOR = "green"
    _GATE_COLOR = "lightblue"

    def __init__(self, c: Circuit):
        """
        Unpack the DAG data of a circuit and index its edges.

        :param c: Circuit to be visualised
        """
        (
            q_inputs,
            c_inputs,
            w_inputs,
            q_outputs,
            c_outputs,
            w_outputs,
            input_names,
            output_names,
            node_data,
            edge_data,
        ) = c._dag_data  # noqa: SLF001
        self.q_inputs = q_inputs
        self.c_inputs = c_inputs
        self.w_inputs = w_inputs
        self.q_outputs = q_outputs
        self.c_outputs = c_outputs
        self.w_outputs = w_outputs
        self.input_names = input_names
        self.output_names = output_names
        self.node_data = node_data
        # Built on demand and cached by as_nx(), get_DAG() and get_qubit_graph().
        self.Gnx: nx.MultiDiGraph | None = None
        self.G: gv.Digraph | None = None
        self.Gqc: gv.Graph | None = None
        # (source node, target node) -> [(source port, target port, edge type)]
        self.edge_data: dict[tuple[int, int], list[tuple[int, int, str]]] = defaultdict(
            list
        )
        # (source node, source port) -> number of edges leaving that port
        self.port_counts: dict[tuple[int, int], int] = defaultdict(int)
        for src_node, tgt_node, src_port, tgt_port, edge_type in edge_data:
            self.edge_data[(src_node, tgt_node)].append((src_port, tgt_port, edge_type))
            self.port_counts[(src_node, src_port)] += 1

    def as_nx(self) -> nx.MultiDiGraph:
        """
        Return a logical representation of the circuit as a DAG.

        Every node carries a ``desc`` attribute; every edge carries ``src_port``,
        ``tgt_port``, ``edge_type`` and ``unit_id`` (the input node its wire
        originates from). The graph is built on the first call and cached.

        :returns: Representation of the DAG
        """
        if self.Gnx is not None:
            return self.Gnx
        Gnx = nx.MultiDiGraph()
        for node, desc in self.node_data.items():
            Gnx.add_node(node, desc=desc)
        for (src_node, tgt_node), portpairlist in self.edge_data.items():
            for src_port, tgt_port, edge_type in portpairlist:
                Gnx.add_edge(
                    src_node,
                    tgt_node,
                    src_port=src_port,
                    tgt_port=tgt_port,
                    edge_type=edge_type,
                )

        # Add node IDs to edges. Edges are visited in topological order of the
        # line graph, so the parent edge has already been labelled when a child
        # edge is reached.
        for edge in nx.topological_sort(nx.line_graph(Gnx)):
            src_node, _, _ = edge
            # List parent edges with matching port number
            src_port = Gnx.edges[edge]["src_port"]
            prev_edges = [
                e
                for e in Gnx.in_edges(src_node, keys=True)
                if Gnx.edges[e]["tgt_port"] == src_port
            ]
            if prev_edges:
                # The parent must be unique
                assert len(prev_edges) == 1
                unit_id = Gnx.edges[prev_edges[0]]["unit_id"]
            else:
                # The source must be an input node
                unit_id = src_node
            nx.set_edge_attributes(Gnx, {edge: {"unit_id": unit_id}})

        # Remove unnecessary port attributes to avoid clutter:
        for node in Gnx.nodes:
            if Gnx.in_degree(node) == 1:
                for edge in Gnx.in_edges(node, keys=True):
                    nx.set_edge_attributes(Gnx, {edge: {"tgt_port": None}})
                for edge in Gnx.out_edges(node, keys=True):
                    nx.set_edge_attributes(Gnx, {edge: {"src_port": None}})

        self.Gnx = Gnx
        return Gnx

    def _add_boundary_clusters(self, G: gv.Digraph) -> None:
        """Add one cluster of point nodes for each kind of input and output."""
        cluster_attr = {
            "style": "rounded, filled",
            "color": "lightgrey",
            "margin": "5",
        }
        node_attr = {"fontname": "Courier", "fontsize": "8"}
        ins, outs = self.input_names, self.output_names
        # (cluster name, rank, node colour, boundary nodes, node labels);
        # the order here is the order in which the clusters are emitted.
        clusters = (
            ("cluster_q_inputs", "source", self._Q_COLOR, self.q_inputs, ins),
            ("cluster_c_inputs", "source", self._C_COLOR, self.c_inputs, ins),
            ("cluster_w_inputs", "source", self._W_COLOR, self.w_inputs, ins),
            ("cluster_q_outputs", "sink", self._Q_COLOR, self.q_outputs, outs),
            ("cluster_c_outputs", "sink", self._C_COLOR, self.c_outputs, outs),
            ("cluster_w_outputs", "sink", self._W_COLOR, self.w_outputs, outs),
        )
        for cluster_name, rank, color, nodes, labels in clusters:
            with G.subgraph(name=cluster_name) as c:
                c.attr(rank=rank, **cluster_attr)
                c.node_attr.update(shape="point", color=color)
                for node in nodes:
                    c.node(str((node, 0)), xlabel=labels[node], **node_attr)

    def _add_op_clusters(
        self, G: gv.Digraph, Gnx: nx.MultiDiGraph, boundary_nodes: set[int]
    ) -> None:
        """Add a labelled cluster, holding one point per port, for each operation."""
        # NOTE(review): `boundary_nodes` is assumed to be a set of integer node
        # ids; only membership tests are performed on it here.
        cluster_attr = {
            "style": "rounded, filled",
            "color": self._GATE_COLOR,
            "fontname": "Times-Roman",
            "fontsize": "10",
            "margin": "5",
            "lheight": "100",
        }
        port_attr = {
            "shape": "point",
            "weight": "2",
            "fontname": "Helvetica",
            "fontsize": "8",
        }
        for node, ndata in Gnx.nodes.items():
            if node in boundary_nodes:
                continue
            with G.subgraph(name="cluster_" + str(node)) as c:
                c.attr(label=ndata["desc"], **cluster_attr)
                n_ports = Gnx.in_degree(node)
                if n_ports == 1:
                    # A single port needs no number next to it.
                    c.node(name=str((node, 0)), **port_attr)
                else:
                    for i in range(n_ports):
                        c.node(name=str((node, i)), xlabel=str(i), **port_attr)

    def _add_dag_edges(self, G: gv.Digraph, Gnx: nx.MultiDiGraph) -> None:
        """Add one graphviz edge, coloured by edge type, per edge of ``Gnx``."""
        edge_colors = {
            "Quantum": self._Q_COLOR,
            "Boolean": self._B_COLOR,
            "Classical": self._C_COLOR,
            "WASM": self._W_COLOR,
        }
        edge_attr = {
            "weight": "2",
            "arrowhead": "vee",
            "arrowsize": "0.2",
            "headclip": "true",
            "tailclip": "true",
        }
        for edge, edata in Gnx.edges.items():
            src_node, tgt_node, _ = edge
            # as_nx() blanks the port of single-port nodes; those use port 0.
            src_port = edata["src_port"] or 0
            tgt_port = edata["tgt_port"] or 0
            edge_type = edata["edge_type"]
            src_nodename = str((src_node, src_port))
            tgt_nodename = str((tgt_node, tgt_port))
            G.edge(
                src_nodename, tgt_nodename, color=edge_colors[edge_type], **edge_attr
            )

    def get_DAG(self) -> gv.Digraph:
        """
        Return a visual representation of the DAG as a graphviz object.

        The object is built on the first call and cached.

        :returns: Representation of the DAG
        """
        if self.G is not None:
            return self.G
        G = gv.Digraph(
            "Circuit",
            strict=True,
        )
        G.attr(rankdir="LR", ranksep="0.3", nodesep="0.15", margin="0")
        self._add_boundary_clusters(G)
        boundary_nodes = (
            self.q_inputs
            | self.c_inputs
            | self.w_inputs
            | self.q_outputs
            | self.c_outputs
            | self.w_outputs
        )
        Gnx = self.as_nx()
        self._add_op_clusters(G, Gnx, boundary_nodes)
        self._add_dag_edges(G, Gnx)
        self.G = G
        return G

    @staticmethod
    def _view_in_tempfile(G: gv.Digraph | gv.Graph) -> str:
        """
        Open ``G`` in a viewer, using a fresh temporary file as its base name.

        :param G: graphviz object to view
        :returns: filename of temporary created file
        """
        # Only the name is needed. Close our own handle deterministically before
        # graphviz writes to the path, rather than leaving it to the garbage
        # collector; `delete=False` keeps the file on disk for the caller.
        with NamedTemporaryFile(delete=False) as tmp:
            filename = tmp.name
        G.view(filename, quiet=True)
        return filename

    def save_DAG(self, name: str, fmt: str = "pdf") -> None:
        """
        Save an image of the DAG to a file.

        The actual filename will be "<name>.<fmt>". A wide range of formats is
        supported. See https://graphviz.org/doc/info/output.html.

        :param name: Prefix of file name
        :param fmt: File format, e.g. "pdf", "png", ...
        """
        G = self.get_DAG()
        G.render(name, cleanup=True, format=fmt, quiet=True)

    def view_DAG(self) -> str:
        """
        View the DAG.

        This method creates a temporary file, and returns its filename so that the
        caller may delete it afterwards.

        :returns: filename of temporary created file
        """
        return self._view_in_tempfile(self.get_DAG())

    def get_qubit_graph(self) -> gv.Graph:
        """
        Return a visual representation of the qubit connectivity graph as a graphviz
        object.

        Two qubits are joined when wires originating from both of them enter the
        same node of the DAG. Only edges are emitted, so a qubit that never shares
        a node with another qubit does not appear. The object is built on the
        first call and cached.

        :returns: Representation of the qubit connectivity graph of the circuit
        """
        if self.Gqc is not None:
            return self.Gqc
        Gnx = self.as_nx()
        # Collect the pairs in a simple graph first so that each pair of qubits
        # yields a single edge however often they interact.
        Gqcnx = nx.Graph()
        for node in Gnx.nodes():
            unit_ids = (Gnx.edges[e]["unit_id"] for e in Gnx.in_edges(node, keys=True))
            qubits = [unit_id for unit_id in unit_ids if unit_id in self.q_inputs]
            Gqcnx.add_edges_from(combinations(qubits, 2))
        G = gv.Graph(
            "Qubit connectivity",
            node_attr={
                "shape": "circle",
                "color": "blue",
                "fontname": "Courier",
                "fontsize": "10",
            },
            engine="neato",
        )
        G.edges(
            (self.input_names[src], self.input_names[tgt]) for src, tgt in Gqcnx.edges()
        )
        self.Gqc = G
        return G

    def view_qubit_graph(self) -> str:
        """
        View the qubit connectivity graph.

        This method creates a temporary file, and returns its filename so that the
        caller may delete it afterwards.

        :returns: filename of temporary created file
        """
        return self._view_in_tempfile(self.get_qubit_graph())

    def save_qubit_graph(self, name: str, fmt: str = "pdf") -> None:
        """
        Save an image of the qubit connectivity graph to a file.

        The actual filename will be "<name>.<fmt>". A wide range of formats is
        supported. See https://graphviz.org/doc/info/output.html.

        :param name: Prefix of file name
        :param fmt: File format, e.g. "pdf", "png", ...
        """
        G = self.get_qubit_graph()
        G.render(name, cleanup=True, format=fmt, quiet=True)