Coverage for /home/runner/work/tket/tket/pytket/pytket/utils/graph.py: 90%

154 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-14 10:02 +0000

1# Copyright Quantinuum 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15from collections import defaultdict 

16from itertools import combinations 

17from tempfile import NamedTemporaryFile 

18 

19import graphviz as gv # type: ignore 

20import networkx as nx # type: ignore 

21 

22from pytket.circuit import Circuit 

23 

24 

class Graph:
    def __init__(self, c: Circuit):
        """
        A class for visualising a circuit as a directed acyclic graph (DAG).

        Note: in order to use graph-rendering methods, such as
        :py:meth:`Graph.save_DAG`, it is necessary to have the Graphviz tools installed
        and on your path. See the `Graphviz website <https://graphviz.org/download/>`_
        for instructions on how to install them.

        :param c: Circuit
        :type c: pytket.Circuit
        """
        (
            q_inputs,
            c_inputs,
            w_inputs,
            q_outputs,
            c_outputs,
            w_outputs,
            input_names,
            output_names,
            node_data,
            edge_data,
        ) = c._dag_data
        # Boundary vertex IDs, one collection per wire kind (q/c/w). They are
        # combined with `|` in get_DAG, so presumably sets -- TODO confirm.
        self.q_inputs = q_inputs
        self.c_inputs = c_inputs
        self.w_inputs = w_inputs
        self.q_outputs = q_outputs
        self.c_outputs = c_outputs
        self.w_outputs = w_outputs
        # Maps from boundary vertex ID to its display name.
        self.input_names = input_names
        self.output_names = output_names
        # Maps from vertex ID to its description (used as the vertex label).
        self.node_data = node_data
        # Lazily-built caches for as_nx(), get_DAG() and get_qubit_graph().
        self.Gnx: nx.MultiDiGraph | None = None
        self.G: gv.Digraph | None = None
        self.Gqc: gv.Graph | None = None
        # Edges grouped by (source vertex, target vertex); each entry is
        # (source port, target port, edge type).
        self.edge_data: dict[tuple[int, int], list[tuple[int, int, str]]] = defaultdict(
            list
        )
        # Number of edges leaving each (source vertex, source port) pair.
        self.port_counts: dict[tuple[int, int], int] = defaultdict(int)
        for src_node, tgt_node, src_port, tgt_port, edge_type in edge_data:
            self.edge_data[(src_node, tgt_node)].append((src_port, tgt_port, edge_type))
            self.port_counts[(src_node, src_port)] += 1

    def as_nx(self) -> nx.MultiDiGraph:
        """
        Return a logical representation of the circuit as a DAG.

        Every vertex carries a ``desc`` attribute. Every edge carries
        ``src_port``, ``tgt_port``, ``edge_type`` and ``unit_id``. ``unit_id`` is
        the input vertex whose wire the edge belongs to. Port attributes are set
        to ``None`` around vertices with exactly one incoming edge. The result
        is cached on the instance.

        :returns: Representation of the DAG
        :rtype: networkx.MultiDiGraph
        """
        if self.Gnx is not None:
            return self.Gnx
        Gnx = nx.MultiDiGraph()
        for node, desc in self.node_data.items():
            Gnx.add_node(node, desc=desc)
        for (src_node, tgt_node), portpairlist in self.edge_data.items():
            for src_port, tgt_port, edge_type in portpairlist:
                Gnx.add_edge(
                    src_node,
                    tgt_node,
                    src_port=src_port,
                    tgt_port=tgt_port,
                    edge_type=edge_type,
                )

        # Add node IDs to edges. Edges are visited in topological order of the
        # line graph, so the parent edge on the same wire is always labelled
        # before its child.
        for edge in nx.topological_sort(nx.line_graph(Gnx)):
            src_node, _tgt_node, _ = edge
            edge_attrs = Gnx.edges[edge]
            # List parent edges with matching port number
            src_port = edge_attrs["src_port"]
            prev_edges = [
                e
                for e in Gnx.in_edges(src_node, keys=True)
                if Gnx.edges[e]["tgt_port"] == src_port
            ]
            if not prev_edges:
                # The source must be an input node
                unit_id = src_node
            else:
                # The parent must be unique
                assert len(prev_edges) == 1
                unit_id = Gnx.edges[prev_edges[0]]["unit_id"]
            edge_attrs["unit_id"] = unit_id

        # Remove unnecessary port attributes to avoid clutter: vertices with a
        # single incoming edge get no port labels on either side.
        for node in Gnx.nodes:
            if Gnx.in_degree(node) != 1:
                continue
            for edge in Gnx.in_edges(node, keys=True):
                Gnx.edges[edge]["tgt_port"] = None
            for edge in Gnx.out_edges(node, keys=True):
                Gnx.edges[edge]["src_port"] = None

        self.Gnx = Gnx
        return Gnx

    def get_DAG(self) -> gv.Digraph:
        """
        Return a visual representation of the DAG as a graphviz object.

        The result is cached on the instance.

        :returns: Representation of the DAG
        :rtype: graphviz.Digraph
        """
        if self.G is not None:
            return self.G
        G = gv.Digraph(
            "Circuit",
            strict=True,
        )
        G.attr(rankdir="LR", ranksep="0.3", nodesep="0.15", margin="0")
        q_color = "blue"
        c_color = "slategray"
        b_color = "gray"
        w_color = "green"
        gate_color = "lightblue"
        boundary_cluster_attr = {
            "style": "rounded, filled",
            "color": "lightgrey",
            "margin": "5",
        }
        boundary_node_attr = {"fontname": "Courier", "fontsize": "8"}
        # One cluster per (side, wire kind): inputs are pinned to the source
        # rank and outputs to the sink rank. Order matches node emission order.
        boundary_clusters = (
            ("cluster_q_inputs", "source", q_color, self.q_inputs, self.input_names),
            ("cluster_c_inputs", "source", c_color, self.c_inputs, self.input_names),
            ("cluster_w_inputs", "source", w_color, self.w_inputs, self.input_names),
            ("cluster_q_outputs", "sink", q_color, self.q_outputs, self.output_names),
            ("cluster_c_outputs", "sink", c_color, self.c_outputs, self.output_names),
            ("cluster_w_outputs", "sink", w_color, self.w_outputs, self.output_names),
        )
        for cluster_name, rank, color, nodes, names in boundary_clusters:
            with G.subgraph(name=cluster_name) as c:
                c.attr(rank=rank, **boundary_cluster_attr)
                c.node_attr.update(shape="point", color=color)
                for node in nodes:
                    c.node(str((node, 0)), xlabel=names[node], **boundary_node_attr)
        boundary_nodes = (
            self.q_inputs
            | self.c_inputs
            | self.w_inputs
            | self.q_outputs
            | self.c_outputs
            | self.w_outputs
        )
        Gnx = self.as_nx()
        node_cluster_attr = {
            "style": "rounded, filled",
            "color": gate_color,
            "fontname": "Times-Roman",
            "fontsize": "10",
            "margin": "5",
            "lheight": "100",
        }
        port_node_attr = {
            "shape": "point",
            "weight": "2",
            "fontname": "Helvetica",
            "fontsize": "8",
        }
        # Each non-boundary vertex becomes a labelled cluster containing one
        # point node per incoming port, named "(vertex, port)".
        for node, ndata in Gnx.nodes.items():
            if node in boundary_nodes:
                continue
            with G.subgraph(name="cluster_" + str(node)) as c:
                c.attr(label=ndata["desc"], **node_cluster_attr)
                n_ports = Gnx.in_degree(node)
                if n_ports == 1:
                    # Single port: as_nx() stripped its port numbers, so no label.
                    c.node(name=str((node, 0)), **port_node_attr)
                else:
                    for i in range(n_ports):
                        c.node(name=str((node, i)), xlabel=str(i), **port_node_attr)
        edge_colors = {
            "Quantum": q_color,
            "Boolean": b_color,
            "Classical": c_color,
            "WASM": w_color,
        }
        edge_attr = {
            "weight": "2",
            "arrowhead": "vee",
            "arrowsize": "0.2",
            "headclip": "true",
            "tailclip": "true",
        }
        for edge, edata in Gnx.edges.items():
            src_node, tgt_node, _ = edge
            # Ports stripped by as_nx() (None) map to the single port node 0.
            src_port = edata["src_port"] or 0
            tgt_port = edata["tgt_port"] or 0
            edge_type = edata["edge_type"]
            src_nodename = str((src_node, src_port))
            tgt_nodename = str((tgt_node, tgt_port))
            G.edge(
                src_nodename, tgt_nodename, color=edge_colors[edge_type], **edge_attr
            )
        self.G = G
        return G

    def save_DAG(self, name: str, fmt: str = "pdf") -> None:
        """
        Save an image of the DAG to a file.

        The actual filename will be "<name>.<fmt>". A wide range of formats is
        supported. See https://graphviz.org/doc/info/output.html.

        :param name: Prefix of file name
        :type name: str
        :param fmt: File format, e.g. "pdf", "png", ...
        :type fmt: str
        """
        G = self.get_DAG()
        G.render(name, cleanup=True, format=fmt, quiet=True)

    @staticmethod
    def _new_temp_filename() -> str:
        """Create a persistent temporary file, close its handle, return its path."""
        # delete=False keeps the file after the handle is closed; closing it
        # here avoids leaking the open file object.
        with NamedTemporaryFile(delete=False) as tmp:
            return tmp.name

    def view_DAG(self) -> str:
        """
        View the DAG.

        This method creates a temporary file, and returns its filename so that the
        caller may delete it afterwards.

        :returns: filename of temporary created file
        :rtype: str
        """
        G = self.get_DAG()
        filename = self._new_temp_filename()
        # NOTE(review): Graphviz may also write a rendered file next to this one
        # (e.g. "<filename>.pdf"); confirm whether callers should clean it up too.
        G.view(filename, quiet=True)
        return filename

    def get_qubit_graph(self) -> gv.Graph:
        """
        Return a visual representation of the qubit connectivity graph as a graphviz
        object.

        Two qubits are connected if their wires both enter the same DAG vertex.
        The result is cached on the instance.

        :returns: Representation of the qubit connectivity graph of the circuit
        :rtype: graphviz.Graph
        """
        if self.Gqc is not None:
            return self.Gqc
        Gnx = self.as_nx()
        Gqcnx = nx.Graph()
        for node in Gnx.nodes():
            # Qubit input vertices whose wires feed into this vertex.
            qubits = [
                unit_id
                for unit_id in (
                    Gnx.edges[e]["unit_id"] for e in Gnx.in_edges(node, keys=True)
                )
                if unit_id in self.q_inputs
            ]
            Gqcnx.add_edges_from(combinations(qubits, 2))
        G = gv.Graph(
            "Qubit connectivity",
            node_attr={
                "shape": "circle",
                "color": "blue",
                "fontname": "Courier",
                "fontsize": "10",
            },
            engine="neato",
        )
        G.edges(
            (self.input_names[src], self.input_names[tgt]) for src, tgt in Gqcnx.edges()
        )
        self.Gqc = G
        return G

    def view_qubit_graph(self) -> str:
        """
        View the qubit connectivity graph.

        This method creates a temporary file, and returns its filename so that the
        caller may delete it afterwards.

        :returns: filename of temporary created file
        :rtype: str
        """
        G = self.get_qubit_graph()
        filename = self._new_temp_filename()
        # NOTE(review): Graphviz may also write a rendered file next to this one
        # (e.g. "<filename>.pdf"); confirm whether callers should clean it up too.
        G.view(filename, quiet=True)
        return filename

    def save_qubit_graph(self, name: str, fmt: str = "pdf") -> None:
        """
        Save an image of the qubit connectivity graph to a file.

        The actual filename will be "<name>.<fmt>". A wide range of formats is
        supported. See https://graphviz.org/doc/info/output.html.

        :param name: Prefix of file name
        :type name: str
        :param fmt: File format, e.g. "pdf", "png", ...
        :type fmt: str
        """
        G = self.get_qubit_graph()
        G.render(name, cleanup=True, format=fmt, quiet=True)