Coverage for /home/runner/work/tket/tket/pytket/pytket/utils/graph.py: 89%
167 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-11 07:40 +0000
1# Copyright Quantinuum
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
15from collections import defaultdict
16from itertools import combinations
17from tempfile import NamedTemporaryFile
19import graphviz as gv # type: ignore
20import networkx as nx # type: ignore
22from pytket.circuit import Circuit
class Graph:
    """
    A class for visualising a circuit as a directed acyclic graph (DAG).

    Note: in order to use graph-rendering methods, such as
    :py:meth:`Graph.save_DAG`, it is necessary to have the Graphviz tools installed
    and on your path. See the `Graphviz website <https://graphviz.org/download/>`_
    for instructions on how to install them.
    """

    def __init__(self, c: Circuit):
        """
        Extract the DAG data of a circuit.

        The networkx and graphviz representations are built lazily, on the first
        call to :py:meth:`as_nx`, :py:meth:`get_DAG` or :py:meth:`get_qubit_graph`,
        and cached on the instance.

        :param c: Circuit to visualise
        """
        (
            q_inputs,
            c_inputs,
            w_inputs,
            r_inputs,
            q_outputs,
            c_outputs,
            w_outputs,
            r_outputs,
            input_names,
            output_names,
            node_data,
            edge_data,
        ) = c._dag_data  # noqa: SLF001
        self.q_inputs = q_inputs
        self.c_inputs = c_inputs
        self.w_inputs = w_inputs
        self.r_inputs = r_inputs
        self.q_outputs = q_outputs
        self.c_outputs = c_outputs
        self.w_outputs = w_outputs
        self.r_outputs = r_outputs
        self.input_names = input_names
        self.output_names = output_names
        self.node_data = node_data
        # Lazily-built caches for as_nx(), get_DAG() and get_qubit_graph().
        self.Gnx: nx.MultiDiGraph | None = None
        self.G: gv.Digraph | None = None
        self.Gqc: gv.Graph | None = None
        # Edges grouped by (source node, target node); each entry is
        # (source port, target port, edge type).
        self.edge_data: dict[tuple[int, int], list[tuple[int, int, str]]] = defaultdict(
            list
        )
        # Number of outgoing edges leaving each (node, source port) pair.
        self.port_counts: dict[tuple[int, int], int] = defaultdict(int)
        for src_node, tgt_node, src_port, tgt_port, edge_type in edge_data:
            self.edge_data[(src_node, tgt_node)].append((src_port, tgt_port, edge_type))
            self.port_counts[(src_node, src_port)] += 1

    def as_nx(self) -> nx.MultiDiGraph:
        """
        Return a logical representation of the circuit as a DAG.

        Each node has a ``desc`` attribute. Each edge has ``src_port``,
        ``tgt_port``, ``edge_type`` and ``unit_id`` attributes, where ``unit_id``
        is the input node from which the edge's wire originates. On edges into
        and out of a node with exactly one incoming edge, the port attributes
        are set to ``None`` because they carry no information there.

        The graph is built on the first call and cached. Later calls return the
        same object, so callers should not mutate it.

        :returns: Representation of the DAG
        """
        if self.Gnx is not None:
            return self.Gnx
        Gnx = nx.MultiDiGraph()
        for node, desc in self.node_data.items():
            Gnx.add_node(node, desc=desc)
        for (src_node, tgt_node), portpairlist in self.edge_data.items():
            for src_port, tgt_port, edge_type in portpairlist:
                Gnx.add_edge(
                    src_node,
                    tgt_node,
                    src_port=src_port,
                    tgt_port=tgt_port,
                    edge_type=edge_type,
                )

        # Add node IDs to edges. Visiting edges in topological order of the line
        # graph ensures each edge's parent edge is labelled before the edge itself.
        for edge in nx.topological_sort(nx.line_graph(Gnx)):
            src_node = edge[0]
            # List parent edges with matching port number
            src_port = Gnx.edges[edge]["src_port"]
            prev_edges = [
                e
                for e in Gnx.in_edges(src_node, keys=True)
                if Gnx.edges[e]["tgt_port"] == src_port
            ]
            if not prev_edges:
                # The source must be an input node
                unit_id = src_node
            else:
                # The parent must be unique
                assert len(prev_edges) == 1
                unit_id = Gnx.edges[prev_edges[0]]["unit_id"]
            Gnx.edges[edge]["unit_id"] = unit_id

        # Remove unnecessary port attributes to avoid clutter. This must run after
        # the unit-ID pass above, which relies on the port numbers.
        for node in Gnx.nodes:
            if Gnx.in_degree(node) == 1:
                for edge in Gnx.in_edges(node, keys=True):
                    Gnx.edges[edge]["tgt_port"] = None
                for edge in Gnx.out_edges(node, keys=True):
                    Gnx.edges[edge]["src_port"] = None

        self.Gnx = Gnx
        return Gnx

    def get_DAG(self) -> gv.Digraph:
        """
        Return a visual representation of the DAG as a graphviz object.

        Input and output nodes are drawn as points, grouped into one cluster per
        boundary kind and wire type. Every other node is drawn as a cluster
        labelled with its description, containing one point per incoming edge.
        Edges are coloured by edge type. The result is built on the first call
        and cached.

        :returns: Representation of the DAG
        """
        if self.G is not None:
            return self.G
        G = gv.Digraph(
            "Circuit",
            strict=True,
        )
        G.attr(rankdir="LR", ranksep="0.3", nodesep="0.15", margin="0")
        q_color = "blue"
        c_color = "slategray"
        b_color = "gray"
        w_color = "green"
        r_color = "red"
        gate_color = "lightblue"
        boundary_cluster_attr = {
            "style": "rounded, filled",
            "color": "lightgrey",
            "margin": "5",
        }
        boundary_node_attr = {"fontname": "Courier", "fontsize": "8"}
        # One row per boundary cluster: (name, rank, colour, nodes, label source).
        # Clusters are emitted in this order, so the generated source is stable.
        boundary_clusters = (
            ("cluster_q_inputs", "source", q_color, self.q_inputs, self.input_names),
            ("cluster_c_inputs", "source", c_color, self.c_inputs, self.input_names),
            ("cluster_w_inputs", "source", w_color, self.w_inputs, self.input_names),
            ("cluster_r_inputs", "source", r_color, self.r_inputs, self.input_names),
            ("cluster_q_outputs", "sink", q_color, self.q_outputs, self.output_names),
            ("cluster_c_outputs", "sink", c_color, self.c_outputs, self.output_names),
            ("cluster_w_outputs", "sink", w_color, self.w_outputs, self.output_names),
            ("cluster_r_outputs", "sink", r_color, self.r_outputs, self.output_names),
        )
        for cluster_name, rank, color, nodes, names in boundary_clusters:
            with G.subgraph(name=cluster_name) as c:
                c.attr(rank=rank, **boundary_cluster_attr)
                c.node_attr.update(shape="point", color=color)
                for node in nodes:
                    c.node(str((node, 0)), xlabel=names[node], **boundary_node_attr)
        boundary_nodes = (
            self.q_inputs
            | self.c_inputs
            | self.w_inputs
            | self.r_inputs
            | self.q_outputs
            | self.c_outputs
            | self.w_outputs
            | self.r_outputs
        )
        Gnx = self.as_nx()
        node_cluster_attr = {
            "style": "rounded, filled",
            "color": gate_color,
            "fontname": "Times-Roman",
            "fontsize": "10",
            "margin": "5",
            "lheight": "100",
        }
        port_node_attr = {
            "shape": "point",
            "weight": "2",
            "fontname": "Helvetica",
            "fontsize": "8",
        }
        for node, ndata in Gnx.nodes.items():
            if node not in boundary_nodes:
                with G.subgraph(name="cluster_" + str(node)) as c:
                    c.attr(label=ndata["desc"], **node_cluster_attr)
                    n_ports = Gnx.in_degree(node)
                    if n_ports == 1:
                        # Single port: as_nx() stripped its port number, so the
                        # port is unlabelled and addressed as port 0.
                        c.node(name=str((node, 0)), **port_node_attr)
                    else:
                        for i in range(n_ports):
                            c.node(name=str((node, i)), xlabel=str(i), **port_node_attr)
        edge_colors = {
            "Quantum": q_color,
            "Boolean": b_color,
            "Classical": c_color,
            "WASM": w_color,
            "RNG": r_color,
        }
        edge_attr = {
            "weight": "2",
            "arrowhead": "vee",
            "arrowsize": "0.2",
            "headclip": "true",
            "tailclip": "true",
        }
        for edge, edata in Gnx.edges.items():
            src_node, tgt_node, _ = edge
            # Ports set to None by as_nx() map to the single port 0.
            src_port = edata["src_port"] or 0
            tgt_port = edata["tgt_port"] or 0
            edge_type = edata["edge_type"]
            src_nodename = str((src_node, src_port))
            tgt_nodename = str((tgt_node, tgt_port))
            G.edge(
                src_nodename, tgt_nodename, color=edge_colors[edge_type], **edge_attr
            )
        self.G = G
        return G

    @staticmethod
    def _view_in_temp_file(G: gv.Graph | gv.Digraph) -> str:
        """
        Render and open a graphviz object via a fresh temporary file.

        :param G: Graph to view
        :returns: Name of the temporary file; the caller is responsible for
            deleting it. Graphviz may also leave rendered output next to it
            (e.g. ``<filename>.pdf``).
        """
        # delete=False: the file must outlive this call. Closing the handle in a
        # `with` releases it deterministically instead of relying on GC.
        with NamedTemporaryFile(delete=False) as tmp:
            filename = tmp.name
        G.view(filename, quiet=True)
        return filename

    def save_DAG(self, name: str, fmt: str = "pdf") -> None:
        """
        Save an image of the DAG to a file.

        The actual filename will be "<name>.<fmt>". A wide range of formats is
        supported. See https://graphviz.org/doc/info/output.html.

        :param name: Prefix of file name
        :param fmt: File format, e.g. "pdf", "png", ...
        """
        G = self.get_DAG()
        G.render(name, cleanup=True, format=fmt, quiet=True)

    def view_DAG(self) -> str:
        """
        View the DAG.

        This method creates a temporary file, and returns its filename so that the
        caller may delete it afterwards.

        :returns: filename of temporary created file
        """
        return self._view_in_temp_file(self.get_DAG())

    def get_qubit_graph(self) -> gv.Graph:
        """
        Return a visual representation of the qubit connectivity graph as a graphviz
        object.

        Two qubits are connected if some node of the DAG has incoming edges from
        both of their wires. Only edges are added to the graph, so a qubit that
        never shares a node with another qubit does not appear. The result is
        built on the first call and cached.

        :returns: Representation of the qubit connectivity graph of the circuit
        """
        if self.Gqc is not None:
            return self.Gqc
        Gnx = self.as_nx()
        Gqcnx = nx.Graph()
        for node in Gnx.nodes():
            # Qubit wires (identified by their input node) entering this node.
            qubits = [
                Gnx.edges[e]["unit_id"]
                for e in Gnx.in_edges(node, keys=True)
                if Gnx.edges[e]["unit_id"] in self.q_inputs
            ]
            Gqcnx.add_edges_from(combinations(qubits, 2))
        G = gv.Graph(
            "Qubit connectivity",
            node_attr={
                "shape": "circle",
                "color": "blue",
                "fontname": "Courier",
                "fontsize": "10",
            },
            engine="neato",
        )
        G.edges(
            (self.input_names[src], self.input_names[tgt]) for src, tgt in Gqcnx.edges()
        )
        self.Gqc = G
        return G

    def view_qubit_graph(self) -> str:
        """
        View the qubit connectivity graph.

        This method creates a temporary file, and returns its filename so that the
        caller may delete it afterwards.

        :returns: filename of temporary created file
        """
        return self._view_in_temp_file(self.get_qubit_graph())

    def save_qubit_graph(self, name: str, fmt: str = "pdf") -> None:
        """
        Save an image of the qubit connectivity graph to a file.

        The actual filename will be "<name>.<fmt>". A wide range of formats is
        supported. See https://graphviz.org/doc/info/output.html.

        :param name: Prefix of file name
        :param fmt: File format, e.g. "pdf", "png", ...
        """
        G = self.get_qubit_graph()
        G.render(name, cleanup=True, format=fmt, quiet=True)