Coverage for /home/runner/work/tket/tket/pytket/pytket/passes/passselector.py: 100%
24 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-14 11:30 +0000
1# Copyright Quantinuum
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
15from collections.abc import Callable
17from pytket.circuit import Circuit
19from .._tket.passes import BasePass
class PassSelector:
    """
    Collection of pytket compilation passes which are
    all applied to the same circuit. The result of the
    compilation is the best circuit as selected by a given metric.
    """

    def __init__(
        self,
        passlist: list[BasePass],
        score_func: Callable[[Circuit], int],
    ):
        """
        Constructs a PassSelector

        :param passlist: list of pytket compilation passes
        :param score_func: function to score the
            results of the compilation (lower scores are preferred)
        :raises ValueError: if ``passlist`` is empty
        """
        if len(passlist) < 1:
            raise ValueError("passlist needs to contain at least one pass")
        self._passlist = passlist
        self._score_func = score_func
        # Filled in by apply(); initialised here so that get_scores() can be
        # called before the first apply() without raising AttributeError.
        self._scores: list[int | None] = []

    def apply(self, circ: Circuit) -> Circuit:
        """
        Compiles the given circuit with the best of the given passes.

        Every pass is applied to its own copy of ``circ``, so the input circuit
        is left unmodified. A pass that raises while being applied, or whose
        result makes the score function raise, is skipped and recorded with a
        score of ``None``. When several passes share the lowest score, the one
        appearing first in the pass list wins.

        :param circ: Circuit that should be compiled
        :return: compiled circuit with the lowest score
        :raises RuntimeError: if no pass could be applied and scored
        """
        circ_list = [circ.copy() for _ in self._passlist]

        self._scores = []
        for p, c in zip(self._passlist, circ_list):
            try:
                p.apply(c)
                self._scores.append(self._score_func(c))
            except Exception:
                # In case of any error the pass should be ignored. Catch
                # Exception rather than using a bare except, so that
                # KeyboardInterrupt / SystemExit still propagate.
                self._scores.append(None)

        # (score, index) pairs for the passes that succeeded. Comparing tuples
        # breaks ties on the index, so the first minimal pass is chosen.
        successful = [(s, i) for i, s in enumerate(self._scores) if s is not None]
        if not successful:
            raise RuntimeError("No passes have successfully run on this circuit")
        _, best_index = min(successful)
        return circ_list[best_index]

    def get_scores(self) -> list[int | None]:
        """
        :return: scores of the circuit after compiling
            for each of the compilations passes, in pass-list order. An entry
            is ``None`` when the corresponding pass failed. The list is empty
            if :py:meth:`apply` has not been called yet. A copy is returned,
            so modifying it does not affect this object.
        """
        return list(self._scores)