Coverage for /home/runner/work/tket/tket/pytket/pytket/utils/graph.py: 90%
154 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-09 15:08 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-09 15:08 +0000
1# Copyright Quantinuum
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
15from collections import defaultdict
16from itertools import combinations
17from tempfile import NamedTemporaryFile
19import graphviz as gv # type: ignore
20import networkx as nx # type: ignore
22from pytket.circuit import Circuit
class Graph:
    """
    A class for visualising a circuit as a directed acyclic graph (DAG).

    Note: in order to use graph-rendering methods, such as
    :py:meth:`Graph.save_DAG`, it is necessary to have the Graphviz tools installed
    and on your path. See the `Graphviz website <https://graphviz.org/download/>`_
    for instructions on how to install them.
    """

    def __init__(self, c: Circuit):
        """
        Extract the DAG data of a circuit.

        No graph objects are built here. They are constructed on first use, and
        cached, by :py:meth:`as_nx`, :py:meth:`get_DAG` and
        :py:meth:`get_qubit_graph`.

        :param c: Circuit to visualise
        :type c: Circuit
        """
        (
            q_inputs,
            c_inputs,
            w_inputs,
            q_outputs,
            c_outputs,
            w_outputs,
            input_names,
            output_names,
            node_data,
            edge_data,
        ) = c._dag_data  # noqa: SLF001
        self.q_inputs = q_inputs
        self.c_inputs = c_inputs
        self.w_inputs = w_inputs
        self.q_outputs = q_outputs
        self.c_outputs = c_outputs
        self.w_outputs = w_outputs
        self.input_names = input_names
        self.output_names = output_names
        self.node_data = node_data
        # Caches for the networkx DAG, the graphviz DAG and the graphviz
        # qubit-connectivity graph; each is filled on first request.
        self.Gnx: nx.MultiDiGraph | None = None
        self.G: gv.Digraph | None = None
        self.Gqc: gv.Graph | None = None
        # Edges grouped by (source node, target node). Each list entry is the
        # (source port, target port, edge type) of one parallel edge.
        self.edge_data: dict[tuple[int, int], list[tuple[int, int, str]]] = defaultdict(
            list
        )
        # Number of edges leaving each (node, source port) pair.
        self.port_counts: dict = defaultdict(int)
        for src_node, tgt_node, src_port, tgt_port, edge_type in edge_data:
            self.edge_data[(src_node, tgt_node)].append((src_port, tgt_port, edge_type))
            self.port_counts[(src_node, src_port)] += 1

    def as_nx(self) -> nx.MultiDiGraph:
        """
        Return a logical representation of the circuit as a DAG.

        Every node carries a ``desc`` attribute. Every edge carries ``src_port``,
        ``tgt_port``, ``edge_type`` and ``unit_id`` attributes. ``unit_id`` is the
        input node at which the edge's wire starts. Port attributes are reset to
        ``None`` on the edges into and out of any node with exactly one incoming
        edge, to avoid clutter. The graph is cached, so repeated calls return the
        same object.

        :returns: Representation of the DAG
        :rtype: networkx.MultiDiGraph
        """
        if self.Gnx is not None:
            return self.Gnx
        Gnx = nx.MultiDiGraph()
        for node, desc in self.node_data.items():
            Gnx.add_node(node, desc=desc)
        for (src_node, tgt_node), portpairlist in self.edge_data.items():
            for src_port, tgt_port, edge_type in portpairlist:
                Gnx.add_edge(
                    src_node,
                    tgt_node,
                    src_port=src_port,
                    tgt_port=tgt_port,
                    edge_type=edge_type,
                )

        # Add node IDs to edges. Edges are visited in topological order of the
        # line graph, so every edge entering `src_node` has already been given its
        # unit_id by the time an edge leaving `src_node` is processed.
        for edge in nx.topological_sort(nx.line_graph(Gnx)):
            src_node = edge[0]
            # List parent edges with matching port number
            src_port = Gnx.edges[edge]["src_port"]
            prev_edges = [
                e
                for e in Gnx.in_edges(src_node, keys=True)
                if Gnx.edges[e]["tgt_port"] == src_port
            ]
            if prev_edges:
                # The parent must be unique
                assert len(prev_edges) == 1
                unit_id = Gnx.edges[prev_edges[0]]["unit_id"]
            else:
                # The source must be an input node
                unit_id = src_node
            Gnx.edges[edge]["unit_id"] = unit_id

        # Remove unnecessary port attributes to avoid clutter. This must run after
        # the unit_id pass above, which relies on the port numbers.
        for node in Gnx.nodes:
            if Gnx.in_degree(node) != 1:
                continue
            for edge in Gnx.in_edges(node, keys=True):
                Gnx.edges[edge]["tgt_port"] = None
            for edge in Gnx.out_edges(node, keys=True):
                Gnx.edges[edge]["src_port"] = None

        self.Gnx = Gnx
        return Gnx

    def get_DAG(self) -> gv.Digraph:
        """
        Return a visual representation of the DAG as a graphviz object.

        Input and output nodes are grouped into boundary clusters. Each other node
        becomes a cluster of port points labelled with its description. Edges are
        coloured by type: quantum blue, classical slate grey, Boolean grey and WASM
        green. The result is cached, so repeated calls return the same object.

        :returns: Representation of the DAG
        :rtype: graphviz.Digraph
        """
        if self.G is not None:
            return self.G
        G = gv.Digraph(
            "Circuit",
            strict=True,
        )
        G.attr(rankdir="LR", ranksep="0.3", nodesep="0.15", margin="0")
        q_color = "blue"
        c_color = "slategray"
        b_color = "gray"
        w_color = "green"
        gate_color = "lightblue"
        boundary_cluster_attr = {
            "style": "rounded, filled",
            "color": "lightgrey",
            "margin": "5",
        }
        boundary_node_attr = {"fontname": "Courier", "fontsize": "8"}
        # (cluster name, rank, point colour, nodes, label lookup), in the order
        # the clusters are emitted.
        boundary_clusters = (
            ("cluster_q_inputs", "source", q_color, self.q_inputs, self.input_names),
            ("cluster_c_inputs", "source", c_color, self.c_inputs, self.input_names),
            ("cluster_w_inputs", "source", w_color, self.w_inputs, self.input_names),
            ("cluster_q_outputs", "sink", q_color, self.q_outputs, self.output_names),
            ("cluster_c_outputs", "sink", c_color, self.c_outputs, self.output_names),
            ("cluster_w_outputs", "sink", w_color, self.w_outputs, self.output_names),
        )
        for cluster_name, rank, color, nodes, names in boundary_clusters:
            with G.subgraph(name=cluster_name) as c:
                c.attr(rank=rank, **boundary_cluster_attr)
                c.node_attr.update(shape="point", color=color)
                for node in nodes:
                    c.node(str((node, 0)), xlabel=names[node], **boundary_node_attr)
        boundary_nodes = (
            self.q_inputs
            | self.c_inputs
            | self.w_inputs
            | self.q_outputs
            | self.c_outputs
            | self.w_outputs
        )
        Gnx = self.as_nx()
        node_cluster_attr = {
            "style": "rounded, filled",
            "color": gate_color,
            "fontname": "Times-Roman",
            "fontsize": "10",
            "margin": "5",
            "lheight": "100",
        }
        port_node_attr = {
            "shape": "point",
            "weight": "2",
            "fontname": "Helvetica",
            "fontsize": "8",
        }
        for node, ndata in Gnx.nodes.items():
            if node in boundary_nodes:
                continue
            with G.subgraph(name="cluster_" + str(node)) as c:
                c.attr(label=ndata["desc"], **node_cluster_attr)
                n_ports = Gnx.in_degree(node)
                # A single port is left unlabelled; multiple ports show their index.
                if n_ports == 1:
                    c.node(name=str((node, 0)), **port_node_attr)
                else:
                    for i in range(n_ports):
                        c.node(name=str((node, i)), xlabel=str(i), **port_node_attr)
        edge_colors = {
            "Quantum": q_color,
            "Boolean": b_color,
            "Classical": c_color,
            "WASM": w_color,
        }
        edge_attr = {
            "weight": "2",
            "arrowhead": "vee",
            "arrowsize": "0.2",
            "headclip": "true",
            "tailclip": "true",
        }
        for (src_node, tgt_node, _), edata in Gnx.edges.items():
            # Ports stripped to None by as_nx() are drawn at port node 0.
            src_port = edata["src_port"] or 0
            tgt_port = edata["tgt_port"] or 0
            G.edge(
                str((src_node, src_port)),
                str((tgt_node, tgt_port)),
                color=edge_colors[edata["edge_type"]],
                **edge_attr,
            )
        self.G = G
        return G

    def save_DAG(self, name: str, fmt: str = "pdf") -> None:
        """
        Save an image of the DAG to a file.

        The actual filename will be "<name>.<fmt>". A wide range of formats is
        supported. See https://graphviz.org/doc/info/output.html.

        :param name: Prefix of file name
        :type name: str
        :param fmt: File format, e.g. "pdf", "png", ...
        :type fmt: str
        """
        G = self.get_DAG()
        G.render(name, cleanup=True, format=fmt, quiet=True)

    def view_DAG(self) -> str:
        """
        View the DAG.

        This method creates a temporary file, and returns its filename so that the
        caller may delete it afterwards.

        :returns: filename of temporary created file
        """
        G = self.get_DAG()
        # Close the handle immediately; delete=False keeps the file on disk.
        with NamedTemporaryFile(delete=False) as tmp:
            filename = tmp.name
        G.view(filename, quiet=True)
        return filename

    def get_qubit_graph(self) -> gv.Graph:
        """
        Return a visual representation of the qubit connectivity graph as a graphviz
        object.

        Two qubits are connected when some node of the DAG has incoming edges on
        both of their wires. Only qubits with at least one such connection appear,
        because nodes are added solely through edges. The result is cached, so
        repeated calls return the same object.

        :returns: Representation of the qubit connectivity graph of the circuit
        :rtype: graphviz.Graph
        """
        if self.Gqc is not None:
            return self.Gqc
        Gnx = self.as_nx()
        Gqcnx = nx.Graph()
        for node in Gnx.nodes():
            # Qubit wires (identified by their input node) entering this node.
            qubits = [
                Gnx.edges[e]["unit_id"]
                for e in Gnx.in_edges(node, keys=True)
                if Gnx.edges[e]["unit_id"] in self.q_inputs
            ]
            Gqcnx.add_edges_from(combinations(qubits, 2))
        G = gv.Graph(
            "Qubit connectivity",
            node_attr={
                "shape": "circle",
                "color": "blue",
                "fontname": "Courier",
                "fontsize": "10",
            },
            engine="neato",
        )
        G.edges(
            (self.input_names[src], self.input_names[tgt]) for src, tgt in Gqcnx.edges()
        )
        self.Gqc = G
        return G

    def view_qubit_graph(self) -> str:
        """
        View the qubit connectivity graph.

        This method creates a temporary file, and returns its filename so that the
        caller may delete it afterwards.

        :returns: filename of temporary created file
        """
        G = self.get_qubit_graph()
        # Close the handle immediately; delete=False keeps the file on disk.
        with NamedTemporaryFile(delete=False) as tmp:
            filename = tmp.name
        G.view(filename, quiet=True)
        return filename

    def save_qubit_graph(self, name: str, fmt: str = "pdf") -> None:
        """
        Save an image of the qubit connectivity graph to a file.

        The actual filename will be "<name>.<fmt>". A wide range of formats is
        supported. See https://graphviz.org/doc/info/output.html.

        :param name: Prefix of file name
        :type name: str
        :param fmt: File format, e.g. "pdf", "png", ...
        :type fmt: str
        """
        G = self.get_qubit_graph()
        G.render(name, cleanup=True, format=fmt, quiet=True)