Coverage for /home/runner/work/tket/tket/pytket/pytket/utils/graph.py: 90%
154 statements
« prev ^ index » next coverage.py v7.8.2, created at 2025-06-02 12:44 +0000
« prev ^ index » next coverage.py v7.8.2, created at 2025-06-02 12:44 +0000
1# Copyright Quantinuum
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
15from collections import defaultdict
16from itertools import combinations
17from tempfile import NamedTemporaryFile
19import graphviz as gv # type: ignore
20import networkx as nx # type: ignore
22from pytket.circuit import Circuit
class Graph:
    """
    A class for visualising a circuit as a directed acyclic graph (DAG).

    Note: in order to use graph-rendering methods, such as
    :py:meth:`Graph.save_DAG`, it is necessary to have the Graphviz tools installed
    and on your path. See the `Graphviz website <https://graphviz.org/download/>`_
    for instructions on how to install them.
    """

    def __init__(self, c: Circuit):
        """
        Extract the DAG data of a circuit.

        The networkx and Graphviz representations are built lazily on first
        request and cached on the instance.

        :param c: Circuit to visualise
        """
        (
            q_inputs,
            c_inputs,
            w_inputs,
            q_outputs,
            c_outputs,
            w_outputs,
            input_names,
            output_names,
            node_data,
            edge_data,
        ) = c._dag_data  # noqa: SLF001
        # Boundary (input/output) vertex ids, split by wire kind: quantum (q_),
        # classical (c_) and WASM (w_). Combined with ``|`` in get_DAG, so they
        # are expected to be set-like.
        self.q_inputs = q_inputs
        self.c_inputs = c_inputs
        self.w_inputs = w_inputs
        self.q_outputs = q_outputs
        self.c_outputs = c_outputs
        self.w_outputs = w_outputs
        # Display names of boundary vertices, indexed by vertex id.
        self.input_names = input_names
        self.output_names = output_names
        # Description of each vertex, indexed by vertex id (used as its label).
        self.node_data = node_data
        # Lazily-built caches (see as_nx, get_DAG and get_qubit_graph).
        self.Gnx: nx.MultiDiGraph | None = None
        self.G: gv.Digraph | None = None
        self.Gqc: gv.Graph | None = None
        # Edges grouped by (source vertex, target vertex): a list of
        # (source port, target port, edge type) triples, in input order.
        self.edge_data: dict[tuple[int, int], list[tuple[int, int, str]]] = defaultdict(
            list
        )
        # Number of out-edges leaving each (vertex, port) pair.
        # NOTE(review): not read anywhere in this class; kept as public state.
        self.port_counts: dict = defaultdict(int)
        for src_node, tgt_node, src_port, tgt_port, edge_type in edge_data:
            self.edge_data[(src_node, tgt_node)].append((src_port, tgt_port, edge_type))
            self.port_counts[(src_node, src_port)] += 1

    def as_nx(self) -> nx.MultiDiGraph:
        """
        Return a logical representation of the circuit as a DAG.

        Every vertex carries a ``desc`` attribute. Every edge carries
        ``src_port``, ``tgt_port``, ``edge_type`` and ``unit_id`` attributes, where
        ``unit_id`` is the vertex at which the wire carrying the edge originates
        (expected to be an input vertex). To avoid clutter, port attributes are
        set to ``None`` on edges entering or leaving a vertex with exactly one
        in-edge. The result is cached.

        :returns: Representation of the DAG
        """
        if self.Gnx is not None:
            return self.Gnx
        Gnx = nx.MultiDiGraph()
        for node, desc in self.node_data.items():
            Gnx.add_node(node, desc=desc)
        for (src_node, tgt_node), portpairlist in self.edge_data.items():
            for src_port, tgt_port, edge_type in portpairlist:
                Gnx.add_edge(
                    src_node,
                    tgt_node,
                    src_port=src_port,
                    tgt_port=tgt_port,
                    edge_type=edge_type,
                )

        # Add unit IDs to edges. Visiting vertices in topological order ensures
        # that every in-edge of a vertex is already labelled by the time its
        # out-edges are processed.
        for node in nx.topological_sort(Gnx):
            in_edges = list(Gnx.in_edges(node, keys=True))
            for edge in Gnx.out_edges(node, keys=True):
                # List parent edges arriving on the port this edge leaves from
                src_port = Gnx.edges[edge]["src_port"]
                prev_edges = [
                    e for e in in_edges if Gnx.edges[e]["tgt_port"] == src_port
                ]
                if not prev_edges:
                    # The source must be an input node
                    unit_id = node
                else:
                    # The parent must be unique
                    assert len(prev_edges) == 1
                    unit_id = Gnx.edges[prev_edges[0]]["unit_id"]
                Gnx.edges[edge]["unit_id"] = unit_id

        # Remove unnecessary port attributes to avoid clutter:
        for node in Gnx.nodes:
            if Gnx.in_degree(node) == 1:
                for edge in Gnx.in_edges(node, keys=True):
                    Gnx.edges[edge]["tgt_port"] = None
                for edge in Gnx.out_edges(node, keys=True):
                    Gnx.edges[edge]["src_port"] = None

        self.Gnx = Gnx
        return Gnx

    def get_DAG(self) -> gv.Digraph:
        """
        Return a visual representation of the DAG as a graphviz object.

        Boundary vertices are drawn as points in source/sink-ranked clusters,
        one cluster per wire kind. Every other vertex becomes a labelled cluster
        with one point node per input port. Edges are coloured by edge type.
        The result is cached.

        :returns: Representation of the DAG
        """
        if self.G is not None:
            return self.G
        G = gv.Digraph(
            "Circuit",
            strict=True,
        )
        G.attr(rankdir="LR", ranksep="0.3", nodesep="0.15", margin="0")
        q_color = "blue"
        c_color = "slategray"
        b_color = "gray"
        w_color = "green"
        gate_color = "lightblue"
        boundary_cluster_attr = {
            "style": "rounded, filled",
            "color": "lightgrey",
            "margin": "5",
        }
        boundary_node_attr = {"fontname": "Courier", "fontsize": "8"}
        # One cluster per (direction, wire kind); the order of this table fixes
        # the order of the clusters in the emitted DOT source.
        boundary_clusters = [
            ("cluster_q_inputs", "source", q_color, self.q_inputs, self.input_names),
            ("cluster_c_inputs", "source", c_color, self.c_inputs, self.input_names),
            ("cluster_w_inputs", "source", w_color, self.w_inputs, self.input_names),
            ("cluster_q_outputs", "sink", q_color, self.q_outputs, self.output_names),
            ("cluster_c_outputs", "sink", c_color, self.c_outputs, self.output_names),
            ("cluster_w_outputs", "sink", w_color, self.w_outputs, self.output_names),
        ]
        for cluster_name, rank, color, nodes, names in boundary_clusters:
            with G.subgraph(name=cluster_name) as c:
                c.attr(rank=rank, **boundary_cluster_attr)
                c.node_attr.update(shape="point", color=color)
                for node in nodes:
                    c.node(str((node, 0)), xlabel=names[node], **boundary_node_attr)
        boundary_nodes = (
            self.q_inputs
            | self.c_inputs
            | self.w_inputs
            | self.q_outputs
            | self.c_outputs
            | self.w_outputs
        )
        Gnx = self.as_nx()
        node_cluster_attr = {
            "style": "rounded, filled",
            "color": gate_color,
            "fontname": "Times-Roman",
            "fontsize": "10",
            "margin": "5",
            "lheight": "100",
        }
        port_node_attr = {
            "shape": "point",
            "weight": "2",
            "fontname": "Helvetica",
            "fontsize": "8",
        }
        for node, ndata in Gnx.nodes.items():
            if node not in boundary_nodes:
                with G.subgraph(name="cluster_" + str(node)) as c:
                    c.attr(label=ndata["desc"], **node_cluster_attr)
                    n_ports = Gnx.in_degree(node)
                    if n_ports == 1:
                        # Single-port vertex: no port label needed
                        c.node(name=str((node, 0)), **port_node_attr)
                    else:
                        for i in range(n_ports):
                            c.node(name=str((node, i)), xlabel=str(i), **port_node_attr)
        edge_colors = {
            "Quantum": q_color,
            "Boolean": b_color,
            "Classical": c_color,
            "WASM": w_color,
        }
        edge_attr = {
            "weight": "2",
            "arrowhead": "vee",
            "arrowsize": "0.2",
            "headclip": "true",
            "tailclip": "true",
        }
        for edge, edata in Gnx.edges.items():
            src_node, tgt_node, _ = edge
            # Ports stripped to None by as_nx map onto the single port node 0
            src_port = edata["src_port"] or 0
            tgt_port = edata["tgt_port"] or 0
            edge_type = edata["edge_type"]
            src_nodename = str((src_node, src_port))
            tgt_nodename = str((tgt_node, tgt_port))
            G.edge(
                src_nodename, tgt_nodename, color=edge_colors[edge_type], **edge_attr
            )
        self.G = G
        return G

    def save_DAG(self, name: str, fmt: str = "pdf") -> None:
        """
        Save an image of the DAG to a file.

        The actual filename will be "<name>.<fmt>". A wide range of formats is
        supported. See https://graphviz.org/doc/info/output.html.

        :param name: Prefix of file name
        :param fmt: File format, e.g. "pdf", "png", ...
        """
        G = self.get_DAG()
        G.render(name, cleanup=True, format=fmt, quiet=True)

    @staticmethod
    def _view_in_tempfile(graph: gv.Graph | gv.Digraph) -> str:
        """
        Open ``graph`` in a viewer via a newly created temporary file.

        :param graph: Graphviz object to render and view
        :returns: name of the temporary file holding the DOT source
        """
        # Close the handle before Graphviz writes to the path. Relying on garbage
        # collection to close it is only prompt under CPython, and a still-open
        # handle can prevent writing on some platforms.
        with NamedTemporaryFile(delete=False) as tmp:
            filename = tmp.name
        graph.view(filename, quiet=True)
        return filename

    def view_DAG(self) -> str:
        """
        View the DAG.

        This method creates a temporary file, and returns its filename so that the
        caller may delete it afterwards. Graphviz writes the DOT source to that
        file and the rendered image next to it, at the same path with the output
        format's extension appended (``.pdf`` by default). Both are left on disk.

        :returns: filename of temporary created file
        """
        return self._view_in_tempfile(self.get_DAG())

    def get_qubit_graph(self) -> gv.Graph:
        """
        Return a visual representation of the qubit connectivity graph as a graphviz
        object.

        Two qubits are connected when they enter a common DAG vertex. Only qubits
        with at least one such connection appear in the graph. The result is
        cached.

        :returns: Representation of the qubit connectivity graph of the circuit
        """
        if self.Gqc is not None:
            return self.Gqc
        Gnx = self.as_nx()
        Gqcnx = nx.Graph()
        for node in Gnx.nodes():
            # Qubits (identified by their input vertex) entering this vertex
            qubits = []
            for e in Gnx.in_edges(node, keys=True):
                unit_id = Gnx.edges[e]["unit_id"]
                if unit_id in self.q_inputs:
                    qubits.append(unit_id)
            # nx.Graph de-duplicates repeated interactions between the same pair
            Gqcnx.add_edges_from(combinations(qubits, 2))
        G = gv.Graph(
            "Qubit connectivity",
            node_attr={
                "shape": "circle",
                "color": "blue",
                "fontname": "Courier",
                "fontsize": "10",
            },
            engine="neato",
        )
        G.edges(
            (self.input_names[src], self.input_names[tgt]) for src, tgt in Gqcnx.edges()
        )
        self.Gqc = G
        return G

    def view_qubit_graph(self) -> str:
        """
        View the qubit connectivity graph.

        This method creates a temporary file, and returns its filename so that the
        caller may delete it afterwards. Graphviz writes the DOT source to that
        file and the rendered image next to it, at the same path with the output
        format's extension appended (``.pdf`` by default). Both are left on disk.

        :returns: filename of temporary created file
        """
        return self._view_in_tempfile(self.get_qubit_graph())

    def save_qubit_graph(self, name: str, fmt: str = "pdf") -> None:
        """
        Save an image of the qubit connectivity graph to a file.

        The actual filename will be "<name>.<fmt>". A wide range of formats is
        supported. See https://graphviz.org/doc/info/output.html.

        :param name: Prefix of file name
        :param fmt: File format, e.g. "pdf", "png", ...
        """
        G = self.get_qubit_graph()
        G.render(name, cleanup=True, format=fmt, quiet=True)