Line | Branch | Decision | Exec | Source |
---|---|---|---|---|
1 | // Copyright 2019-2022 Cambridge Quantum Computing | |||
2 | // | |||
3 | // Licensed under the Apache License, Version 2.0 (the "License"); | |||
4 | // you may not use this file except in compliance with the License. | |||
5 | // You may obtain a copy of the License at | |||
6 | // | |||
7 | // http://www.apache.org/licenses/LICENSE-2.0 | |||
8 | // | |||
9 | // Unless required by applicable law or agreed to in writing, software | |||
10 | // distributed under the License is distributed on an "AS IS" BASIS, | |||
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
12 | // See the License for the specific language governing permissions and | |||
13 | // limitations under the License. | |||
14 | ||||
15 | #pragma once | |||
16 | ||||
17 | #include "CompilationUnit.hpp" | |||
18 | #include "Predicates.hpp" | |||
19 | #include "Utils/Json.hpp" | |||
20 | ||||
21 | namespace tket { | |||
22 | ||||
23 | enum class Guarantee; | |||
24 | struct PostConditions; | |||
25 | class BasePass; | |||
26 | class StandardPass; | |||
27 | class SequencePass; | |||
28 | class RepeatPass; | |||
29 | typedef std::shared_ptr<BasePass> PassPtr; | |||
30 | typedef std::map<std::type_index, Guarantee> PredicateClassGuarantees; | |||
31 | typedef std::pair<PredicatePtrMap, PostConditions> PassConditions; | |||
32 | typedef std::function<void(const CompilationUnit&, const nlohmann::json&)> | |||
33 | PassCallback; | |||
34 | ||||
35 | JSON_DECL(PassPtr) | |||
36 | ||||
37 | class IncompatibleCompilerPasses : public std::logic_error { | |||
38 | public: | |||
39 | 6 | explicit IncompatibleCompilerPasses(const std::type_index& typeid1) | ||
40 | 6 | : std::logic_error( | ||
41 | "Cannot compose these Compiler Passes due to mismatching " | |||
42 | 6 | "Predicates of type: " + | ||
43 |
1/2✓ Branch 3 taken 6 times.
✗ Branch 4 not taken.
|
12 | predicate_name(typeid1)) {} | |
44 | }; | |||
45 | ||||
/** Error raised for a pass that cannot be serialized; the message names it. */
class PassNotSerializable : public std::logic_error {
 public:
  /** @param pass_name name of the offending pass, appended to the message */
  explicit PassNotSerializable(const std::string& pass_name)
      : std::logic_error(std::string("Pass not serializable: ") + pass_name) {}
};
51 | ||||
/**
 * How a `CompilerPass` affects a whole class of Predicates.
 *
 * Class-level guarantees (PredicateClassGuarantees) cannot `Set` Predicates;
 * they can only Clear or Preserve them.
 */
enum class Guarantee {
  Clear,
  Preserve
};
56 | ||||
/** How much checking is done while passes are applied. */
enum class SafetyMode {
  /** Check after every pass that the cache is being updated correctly. */
  Audit,
  /** Check composition and check precons and postcons. */
  Default,
  /** Check only composition (this should perhaps be removed as well). */
  Off
};
62 | ||||
63 | /* Guarantee composition for CompilerPass composition: | |||
64 | Given A, B, ..., Z CompilerPasss, the Guarantee of Predicate 'p' for A >> B | |||
65 | >> ... >> Z is the last non-"Preserve" Guarantee for 'p' in list. If all | |||
66 | "Preserve", Guarantee 'p' := "Preserve" | |||
67 | */ | |||
68 | ||||
69 | // Priority hierarchy: 1) Specific, 2) Generic, 3) Default | |||
70 | struct PostConditions { | |||
71 | PredicatePtrMap specific_postcons_; | |||
72 | PredicateClassGuarantees generic_postcons_; | |||
73 | Guarantee default_postcon_; | |||
74 | ✗ | PostConditions( | ||
75 | const PredicatePtrMap& specific_postcons = {}, | |||
76 | const PredicateClassGuarantees& generic_postcons = {}, | |||
77 | Guarantee default_postcon = Guarantee::Clear) | |||
78 | ✗ | : specific_postcons_(specific_postcons), | ||
79 | ✗ | generic_postcons_(generic_postcons), | ||
80 | ✗ | default_postcon_(default_postcon) {} | ||
81 | }; | |||
82 | ||||
83 | /** | |||
84 | * @brief Default callback when applying a pass (does nothing) | |||
85 | */ | |||
86 | void trivial_callback(const CompilationUnit&, const nlohmann::json&); | |||
87 | ||||
88 | /* Passes are used to generate full sequences of rewrite rules for Circuits. It | |||
89 | internally stores pre and postcons which are composed together. Whenever a | |||
90 | CompilationUnit is passed through a Pass it has its cache of Predicates | |||
91 | updated accordingly. */ | |||
92 | ||||
93 | // Passes can be run in safe or unsafe mode (bool flag to dictate). | |||
94 | // Safe := runs Predicates at every StandardPass to make sure cache is being | |||
95 | // updated correctly Unsafe := When possible, updates the cache without running | |||
96 | // Predicates | |||
97 | // TODO: Super Unsafe AKA Cowabunga Mode := check nothing, allow everything | |||
98 | ||||
99 | class BasePass { | |||
100 | public: | |||
101 |
1/2✓ Branch 4 taken 105 times.
✗ Branch 5 not taken.
|
105 | BasePass() {} | |
102 | ||||
103 | /** | |||
104 | * @brief Apply the pass and invoke callbacks | |||
105 | * @param c_unit | |||
106 | * @param before_apply Called at the start of the apply procedure. | |||
107 | * The parameters are the CompilationUnit and a summary of the pass | |||
108 | * configuration. | |||
109 | * @param after_apply Called at the end of the apply procedure. | |||
110 | * The parameters are the CompilationUnit and a summary of the pass | |||
111 | * configuration. | |||
112 | * @param safe_mode | |||
113 | * @return True if pass modified the circuit, else False | |||
114 | */ | |||
115 | virtual bool apply( | |||
116 | CompilationUnit& c_unit, SafetyMode safe_mode = SafetyMode::Default, | |||
117 | const PassCallback& before_apply = trivial_callback, | |||
118 | const PassCallback& after_apply = trivial_callback) const = 0; | |||
119 | ||||
120 | friend PassPtr operator>>(const PassPtr& lhs, const PassPtr& rhs); | |||
121 | ||||
122 | virtual std::string to_string() const = 0; | |||
123 | ||||
124 | /** | |||
125 | * @brief Get the config object | |||
126 | * | |||
127 | * @return json containing the name, and params for the pass. | |||
128 | */ | |||
129 | virtual nlohmann::json get_config() const = 0; | |||
130 | PassConditions get_conditions() const; | |||
131 | Guarantee get_guarantee(const std::type_index& ti) const; | |||
132 | static Guarantee get_guarantee( | |||
133 | const std::type_index& ti, const PassConditions& conditions); | |||
134 | ||||
135 | ✗ | virtual ~BasePass(){}; | ||
136 | ||||
137 | protected: | |||
138 | ✗ | BasePass(const PredicatePtrMap& precons, const PostConditions& postcons) | ||
139 | ✗ | : precons_(precons), postcons_(postcons) {} | ||
140 | PredicatePtrMap precons_; | |||
141 | PostConditions postcons_; | |||
142 | ||||
143 | /** | |||
144 | * Check whether any preconditions of the compilation unit are unsatisfied. | |||
145 | * | |||
146 | * Returns the first unsatisfied precondition found. | |||
147 | * | |||
148 | * @param c_unit compilation unit | |||
149 | * @param[in] safe_mode safety mode | |||
150 | * | |||
151 | * @return unsatisfied precondition, if any | |||
152 | */ | |||
153 | std::optional<PredicatePtr> unsatisfied_precondition( | |||
154 | const CompilationUnit& c_unit, SafetyMode safe_mode) const; | |||
155 | ||||
156 | void update_cache(const CompilationUnit& c_unit, SafetyMode safe_mode) const; | |||
157 | static PassConditions match_passes(const PassPtr& lhs, const PassPtr& rhs); | |||
158 | static PassConditions match_passes( | |||
159 | const PassConditions& lhs, const PassConditions& rhs); | |||
160 | }; | |||
161 | ||||
162 | /* Basic Pass that all combinators can be used on */ | |||
163 | class StandardPass : public BasePass { | |||
164 | public: | |||
165 | /** | |||
166 | * @brief Construct a new StandardPass object with info about the pass. | |||
167 | * | |||
168 | * @param precons | |||
169 | * @param trans | |||
170 | * @param postcons | |||
171 | * @param pass_config A nlohmann::json object containing the name, and params | |||
172 | * for the pass. | |||
173 | */ | |||
174 | ✗ | StandardPass( | ||
175 | const PredicatePtrMap& precons, const Transform& trans, | |||
176 | const PostConditions& postcons, const nlohmann::json& pass_config) | |||
177 | ✗ | : BasePass(precons, postcons), trans_(trans), pass_config_(pass_config) {} | ||
178 | ||||
179 | bool apply( | |||
180 | CompilationUnit& c_unit, SafetyMode = SafetyMode::Default, | |||
181 | const PassCallback& before_apply = trivial_callback, | |||
182 | const PassCallback& after_apply = trivial_callback) const override; | |||
183 | std::string to_string() const override; | |||
184 | nlohmann::json get_config() const override; | |||
185 | ||||
186 | private: | |||
187 | Transform trans_; | |||
188 | nlohmann::json pass_config_ = "{\"name\": \"StandardPass\"}"_json; | |||
189 | }; | |||
190 | ||||
191 | /* Runs a sequence of Passes */ | |||
192 | class SequencePass : public BasePass { | |||
193 | public: | |||
194 | 59 | SequencePass() {} | ||
195 | explicit SequencePass(const std::vector<PassPtr>& ptvec); | |||
196 | ✗ | bool apply( | ||
197 | CompilationUnit& c_unit, SafetyMode safe_mode = SafetyMode::Default, | |||
198 | const PassCallback& before_apply = trivial_callback, | |||
199 | const PassCallback& after_apply = trivial_callback) const override { | |||
200 | ✗ | before_apply(c_unit, this->get_config()); | ||
201 | ✗ | bool success = false; | ||
202 |
0/2✗ Decision 'true' not taken.
✗ Decision 'false' not taken.
|
✗ | for (const PassPtr& b : seq_) | |
203 | ✗ | success |= b->apply(c_unit, safe_mode, before_apply, after_apply); | ||
204 | ✗ | after_apply(c_unit, this->get_config()); | ||
205 | ✗ | return success; | ||
206 | } | |||
207 | std::string to_string() const override; | |||
208 | nlohmann::json get_config() const override; | |||
209 | std::vector<PassPtr> get_sequence() const { return seq_; } | |||
210 | ||||
211 | friend PassPtr operator>>(const PassPtr& lhs, const PassPtr& rhs); | |||
212 | ||||
213 | private: | |||
214 | std::vector<PassPtr> seq_; | |||
215 | }; | |||
216 | ||||
217 | /* Repeats a Pass until it returns `false` */ | |||
218 | class RepeatPass : public BasePass { | |||
219 | public: | |||
220 | explicit RepeatPass(const PassPtr& pass); | |||
221 | 1 | bool apply( | ||
222 | CompilationUnit& c_unit, SafetyMode safe_mode = SafetyMode::Default, | |||
223 | const PassCallback& before_apply = trivial_callback, | |||
224 | const PassCallback& after_apply = trivial_callback) const override { | |||
225 |
1/2✓ Branch 2 taken 1 times.
✗ Branch 3 not taken.
|
1 | before_apply(c_unit, this->get_config()); | |
226 | 1 | bool success = false; | ||
227 |
1/2✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
|
1/2✗ Decision 'true' not taken.
✓ Decision 'false' taken 1 times.
|
1 | while (pass_->apply(c_unit, safe_mode, before_apply, after_apply)) |
228 | ✗ | success = true; | ||
229 |
1/2✓ Branch 2 taken 1 times.
✗ Branch 3 not taken.
|
1 | after_apply(c_unit, this->get_config()); | |
230 | 1 | return success; | ||
231 | } | |||
232 | std::string to_string() const override; | |||
233 | nlohmann::json get_config() const override; | |||
234 | PassPtr get_pass() const { return pass_; } | |||
235 | ||||
236 | private: | |||
237 | PassPtr pass_; | |||
238 | }; | |||
239 | ||||
240 | class RepeatWithMetricPass : public BasePass { | |||
241 | public: | |||
242 | RepeatWithMetricPass(const PassPtr& pass, const Transform::Metric& metric); | |||
243 | bool apply( | |||
244 | CompilationUnit& c_unit, SafetyMode safe_mode = SafetyMode::Default, | |||
245 | const PassCallback& before_apply = trivial_callback, | |||
246 | const PassCallback& after_apply = trivial_callback) const override; | |||
247 | std::string to_string() const override; | |||
248 | nlohmann::json get_config() const override; | |||
249 | PassPtr get_pass() const { return pass_; } | |||
250 | Transform::Metric get_metric() const { return metric_; } | |||
251 | ||||
252 | private: | |||
253 | PassPtr pass_; | |||
254 | Transform::Metric metric_; | |||
255 | }; | |||
256 | ||||
257 | class RepeatUntilSatisfiedPass : public BasePass { | |||
258 | public: | |||
259 | RepeatUntilSatisfiedPass(const PassPtr& pass, const PredicatePtr& to_satisfy); | |||
260 | RepeatUntilSatisfiedPass( | |||
261 | const PassPtr& pass, const std::function<bool(const Circuit&)>& func) { | |||
262 | PredicatePtr custom_pred = std::make_shared<UserDefinedPredicate>(func); | |||
263 | *this = RepeatUntilSatisfiedPass(pass, custom_pred); | |||
264 | } | |||
265 | /* Careful: If the predicate is never satisfied this will not terminate */ | |||
266 | bool apply( | |||
267 | CompilationUnit& c_unit, SafetyMode safe_mode = SafetyMode::Default, | |||
268 | const PassCallback& before_apply = trivial_callback, | |||
269 | const PassCallback& after_apply = trivial_callback) const override; | |||
270 | std::string to_string() const override; | |||
271 | nlohmann::json get_config() const override; | |||
272 | PassPtr get_pass() const { return pass_; } | |||
273 | PredicatePtr get_predicate() const { return pred_; } | |||
274 | ||||
275 | private: | |||
276 | PassPtr pass_; | |||
277 | PredicatePtr pred_; | |||
278 | }; | |||
279 | ||||
// TODO: further pass combinators. Repeat-with-metric and
// repeat-until-satisfied already exist as RepeatWithMetricPass and
// RepeatUntilSatisfiedPass.
281 | ||||
282 | } // namespace tket | |||
283 |