Coverage for /home/runner/work/tket/tket/pytket/pytket/utils/graph.py: 90%
154 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-14 11:30 +0000
1# Copyright Quantinuum
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
15from collections import defaultdict
16from itertools import combinations
17from tempfile import NamedTemporaryFile
19import graphviz as gv # type: ignore
20import networkx as nx # type: ignore
22from pytket.circuit import Circuit
class Graph:
    def __init__(self, c: Circuit) -> None:
        """
        A class for visualising a circuit as a directed acyclic graph (DAG).

        Note: in order to use graph-rendering methods, such as
        :py:meth:`Graph.save_DAG`, it is necessary to have the Graphviz tools installed
        and on your path. See the `Graphviz website <https://graphviz.org/download/>`_
        for instructions on how to install them.

        :param c: Circuit
        :type c: pytket.Circuit
        """
        (
            q_inputs,
            c_inputs,
            w_inputs,
            q_outputs,
            c_outputs,
            w_outputs,
            input_names,
            output_names,
            node_data,
            edge_data,
        ) = c._dag_data
        # Boundary vertices, grouped by wire type (quantum / classical / WASM).
        self.q_inputs = q_inputs
        self.c_inputs = c_inputs
        self.w_inputs = w_inputs
        self.q_outputs = q_outputs
        self.c_outputs = c_outputs
        self.w_outputs = w_outputs
        # Display names for boundary vertices, keyed by vertex.
        self.input_names = input_names
        self.output_names = output_names
        # Vertex -> description (used as the vertex label when rendering).
        self.node_data = node_data
        # Lazily-built results, cached by as_nx / get_DAG / get_qubit_graph.
        self.Gnx: nx.MultiDiGraph | None = None
        self.G: gv.Digraph | None = None
        self.Gqc: gv.Graph | None = None
        # Edges grouped by (source vertex, target vertex); each entry is a
        # (source port, target port, edge type) triple.
        self.edge_data: dict[tuple[int, int], list[tuple[int, int, str]]] = (
            defaultdict(list)
        )
        # Number of edges leaving each (source vertex, source port) pair.
        self.port_counts: dict[tuple[int, int], int] = defaultdict(int)
        for src_node, tgt_node, src_port, tgt_port, edge_type in edge_data:
            self.edge_data[(src_node, tgt_node)].append((src_port, tgt_port, edge_type))
            self.port_counts[(src_node, src_port)] += 1

    def as_nx(self) -> nx.MultiDiGraph:
        """
        Return a logical representation of the circuit as a DAG.

        Every vertex carries a ``desc`` attribute. Every edge carries ``src_port``,
        ``tgt_port``, ``edge_type`` and ``unit_id`` attributes, where ``unit_id``
        is the vertex with no matching parent edge (an input vertex) from which
        the edge's wire originates. To avoid clutter, port attributes are set to
        ``None`` on edges entering or leaving a vertex with exactly one incoming
        edge.

        The graph is built once and cached; later calls return the same object.

        :returns: Representation of the DAG
        :rtype: networkx.MultiDiGraph
        """
        if self.Gnx is not None:
            return self.Gnx
        Gnx = nx.MultiDiGraph()
        for node, desc in self.node_data.items():
            Gnx.add_node(node, desc=desc)
        for (src_node, tgt_node), portpairlist in self.edge_data.items():
            for src_port, tgt_port, edge_type in portpairlist:
                Gnx.add_edge(
                    src_node,
                    tgt_node,
                    src_port=src_port,
                    tgt_port=tgt_port,
                    edge_type=edge_type,
                )

        # Add node IDs to edges. Visiting edges in topological order of the
        # line graph guarantees that an edge's parent is labelled before it.
        unit_ids: dict = {}
        for edge in nx.topological_sort(nx.line_graph(Gnx)):
            src_node = edge[0]
            src_port = Gnx.edges[edge]["src_port"]
            # The parent is the edge entering src_node on the same port number.
            prev_edges = [
                e
                for e in Gnx.in_edges(src_node, keys=True)
                if Gnx.edges[e]["tgt_port"] == src_port
            ]
            if not prev_edges:
                # The source must be an input node
                unit_ids[edge] = src_node
            else:
                # The parent must be unique
                assert len(prev_edges) == 1
                unit_ids[edge] = unit_ids[prev_edges[0]]
        nx.set_edge_attributes(Gnx, unit_ids, "unit_id")

        # Remove unnecessary port attributes to avoid clutter: around a vertex
        # with a single incoming edge the port numbers carry no information.
        single_in = [node for node in Gnx.nodes if Gnx.in_degree(node) == 1]
        nx.set_edge_attributes(
            Gnx,
            {e: None for node in single_in for e in Gnx.in_edges(node, keys=True)},
            "tgt_port",
        )
        nx.set_edge_attributes(
            Gnx,
            {e: None for node in single_in for e in Gnx.out_edges(node, keys=True)},
            "src_port",
        )

        self.Gnx = Gnx
        return Gnx

    def get_DAG(self) -> gv.Digraph:
        """
        Return a visual representation of the DAG as a graphviz object.

        Boundary vertices are drawn as points grouped into clusters per wire type,
        and every other vertex becomes a labelled cluster with one point per input
        port. The result is cached; later calls return the same object.

        :returns: Representation of the DAG
        :rtype: graphviz.DiGraph
        """
        if self.G is not None:
            return self.G
        G = gv.Digraph(
            "Circuit",
            strict=True,
        )
        G.attr(rankdir="LR", ranksep="0.3", nodesep="0.15", margin="0")
        q_color = "blue"
        c_color = "slategray"
        b_color = "gray"
        w_color = "green"
        gate_color = "lightblue"
        boundary_cluster_attr = {
            "style": "rounded, filled",
            "color": "lightgrey",
            "margin": "5",
        }
        boundary_node_attr = {"fontname": "Courier", "fontsize": "8"}
        # (cluster name, rank, colour, vertices, vertex labels) for each group of
        # boundary vertices, in the order the clusters are emitted.
        boundary_groups = (
            ("cluster_q_inputs", "source", q_color, self.q_inputs, self.input_names),
            ("cluster_c_inputs", "source", c_color, self.c_inputs, self.input_names),
            ("cluster_w_inputs", "source", w_color, self.w_inputs, self.input_names),
            ("cluster_q_outputs", "sink", q_color, self.q_outputs, self.output_names),
            ("cluster_c_outputs", "sink", c_color, self.c_outputs, self.output_names),
            ("cluster_w_outputs", "sink", w_color, self.w_outputs, self.output_names),
        )
        for cluster_name, rank, color, nodes, names in boundary_groups:
            with G.subgraph(name=cluster_name) as c:
                c.attr(rank=rank, **boundary_cluster_attr)
                c.node_attr.update(shape="point", color=color)
                for node in nodes:
                    c.node(str((node, 0)), xlabel=names[node], **boundary_node_attr)
        boundary_nodes = (
            self.q_inputs
            | self.c_inputs
            | self.w_inputs
            | self.q_outputs
            | self.c_outputs
            | self.w_outputs
        )
        Gnx = self.as_nx()
        node_cluster_attr = {
            "style": "rounded, filled",
            "color": gate_color,
            "fontname": "Times-Roman",
            "fontsize": "10",
            "margin": "5",
            "lheight": "100",
        }
        port_node_attr = {
            "shape": "point",
            "weight": "2",
            "fontname": "Helvetica",
            "fontsize": "8",
        }
        for node, ndata in Gnx.nodes.items():
            if node not in boundary_nodes:
                with G.subgraph(name="cluster_" + str(node)) as c:
                    c.attr(label=ndata["desc"], **node_cluster_attr)
                    n_ports = Gnx.in_degree(node)
                    if n_ports == 1:
                        # Single-port vertices get an unlabelled port point.
                        c.node(name=str((node, 0)), **port_node_attr)
                    else:
                        for i in range(n_ports):
                            c.node(name=str((node, i)), xlabel=str(i), **port_node_attr)
        edge_colors = {
            "Quantum": q_color,
            "Boolean": b_color,
            "Classical": c_color,
            "WASM": w_color,
        }
        edge_attr = {
            "weight": "2",
            "arrowhead": "vee",
            "arrowsize": "0.2",
            "headclip": "true",
            "tailclip": "true",
        }
        for edge, edata in Gnx.edges.items():
            src_node, tgt_node, _ = edge
            # Ports stripped by as_nx (None) map onto the single port 0.
            src_port = edata["src_port"] or 0
            tgt_port = edata["tgt_port"] or 0
            edge_type = edata["edge_type"]
            src_nodename = str((src_node, src_port))
            tgt_nodename = str((tgt_node, tgt_port))
            G.edge(
                src_nodename, tgt_nodename, color=edge_colors[edge_type], **edge_attr
            )
        self.G = G
        return G

    def save_DAG(self, name: str, fmt: str = "pdf") -> None:
        """
        Save an image of the DAG to a file.

        The actual filename will be "<name>.<fmt>". A wide range of formats is
        supported. See https://graphviz.org/doc/info/output.html.

        :param name: Prefix of file name
        :type name: str
        :param fmt: File format, e.g. "pdf", "png", ...
        :type fmt: str
        """
        G = self.get_DAG()
        G.render(name, cleanup=True, format=fmt, quiet=True)

    def view_DAG(self) -> str:
        """
        View the DAG.

        This method creates a temporary file, and returns its filename so that the
        caller may delete it afterwards.

        :returns: filename of temporary created file
        """
        G = self.get_DAG()
        # Only a unique path is needed: close the handle immediately so that
        # Graphviz can write to the file (an open named temporary file cannot be
        # reopened on every platform, e.g. Windows).
        with NamedTemporaryFile(delete=False) as tmp:
            filename = tmp.name
        # NOTE(review): graphviz also writes the rendered image next to this
        # file; the caller presumably needs to remove that one too.
        G.view(filename, quiet=True)
        return filename

    def get_qubit_graph(self) -> gv.Graph:
        """
        Return a visual representation of the qubit connectivity graph as a graphviz
        object.

        Two qubits are connected when edges on their wires enter the same vertex.
        Only connections are added to the graph, so a qubit that never shares a
        vertex with another qubit does not appear. The result is cached.

        :returns: Representation of the qubit connectivity graph of the circuit
        :rtype: graphviz.Graph
        """
        if self.Gqc is not None:
            return self.Gqc
        Gnx = self.as_nx()
        Gqcnx = nx.Graph()
        for node in Gnx.nodes():
            # Quantum input vertices whose wires feed into this vertex.
            qubits = [
                unit_id
                for _, _, unit_id in Gnx.in_edges(node, data="unit_id")
                if unit_id in self.q_inputs
            ]
            Gqcnx.add_edges_from(combinations(qubits, 2))
        G = gv.Graph(
            "Qubit connectivity",
            node_attr={
                "shape": "circle",
                "color": "blue",
                "fontname": "Courier",
                "fontsize": "10",
            },
            engine="neato",
        )
        G.edges(
            (self.input_names[src], self.input_names[tgt]) for src, tgt in Gqcnx.edges()
        )
        self.Gqc = G
        return G

    def view_qubit_graph(self) -> str:
        """
        View the qubit connectivity graph.

        This method creates a temporary file, and returns its filename so that the
        caller may delete it afterwards.

        :returns: filename of temporary created file
        """
        G = self.get_qubit_graph()
        # Only a unique path is needed: close the handle immediately so that
        # Graphviz can write to the file (an open named temporary file cannot be
        # reopened on every platform, e.g. Windows).
        with NamedTemporaryFile(delete=False) as tmp:
            filename = tmp.name
        # NOTE(review): graphviz also writes the rendered image next to this
        # file; the caller presumably needs to remove that one too.
        G.view(filename, quiet=True)
        return filename

    def save_qubit_graph(self, name: str, fmt: str = "pdf") -> None:
        """
        Save an image of the qubit connectivity graph to a file.

        The actual filename will be "<name>.<fmt>". A wide range of formats is
        supported. See https://graphviz.org/doc/info/output.html.

        :param name: Prefix of file name
        :type name: str
        :param fmt: File format, e.g. "pdf", "png", ...
        :type fmt: str
        """
        G = self.get_qubit_graph()
        G.render(name, cleanup=True, format=fmt, quiet=True)