GCC Code Coverage Report


Directory: ./
File: Predicates/include/Predicates/CompilerPass.hpp
Date: 2022-10-15 05:10:18
Exec Total Coverage
Lines: 12 31 38.7%
Functions: 4 10 40.0%
Branches: 5 28 17.9%
Decisions: 1 4 25.0%

Line Branch Decision Exec Source
1 // Copyright 2019-2022 Cambridge Quantum Computing
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #pragma once
16
17 #include "CompilationUnit.hpp"
18 #include "Predicates.hpp"
19 #include "Utils/Json.hpp"
20
21 namespace tket {
22
23 enum class Guarantee;
24 struct PostConditions;
25 class BasePass;
26 class StandardPass;
27 class SequencePass;
28 class RepeatPass;
29 typedef std::shared_ptr<BasePass> PassPtr;
30 typedef std::map<std::type_index, Guarantee> PredicateClassGuarantees;
31 typedef std::pair<PredicatePtrMap, PostConditions> PassConditions;
32 typedef std::function<void(const CompilationUnit&, const nlohmann::json&)>
33 PassCallback;
34
35 JSON_DECL(PassPtr)
36
37 class IncompatibleCompilerPasses : public std::logic_error {
38 public:
39 6 explicit IncompatibleCompilerPasses(const std::type_index& typeid1)
40 6 : std::logic_error(
41 "Cannot compose these Compiler Passes due to mismatching "
42 6 "Predicates of type: " +
43
1/2
✓ Branch 3 taken 6 times.
✗ Branch 4 not taken.
12 predicate_name(typeid1)) {}
44 };
45
/**
 * @brief Error signalling that a pass cannot be serialized.
 */
class PassNotSerializable : public std::logic_error {
 public:
  /**
   * @param pass_name name of the offending pass; appended to the message
   *   "Pass not serializable: ".
   */
  explicit PassNotSerializable(const std::string& pass_name)
      : std::logic_error(
            std::string("Pass not serializable: ").append(pass_name)) {}
};
51
// Dictates how a `CompilerPass` should affect a class of Predicates.
// PredicateClassGuarantees cannot `Set` Predicates; they can only Clear or
// Preserve them.
enum class Guarantee {
  Clear,
  Preserve
};
56
/** Amount of checking performed while passes are applied. */
enum class SafetyMode {
  /* Check after every pass that the cache is being updated correctly */
  Audit,
  /* Check composition and check precons and postcons */
  Default,
  /* Check only composition (this should perhaps be removed as well) */
  Off
};
62
63 /* Guarantee composition for CompilerPass composition:
64 Given A, B, ..., Z CompilerPasss, the Guarantee of Predicate 'p' for A >> B
65 >> ... >> Z is the last non-"Preserve" Guarantee for 'p' in list. If all
66 "Preserve", Guarantee 'p' := "Preserve"
67 */
68
69 // Priority hierarchy: 1) Specific, 2) Generic, 3) Default
70 struct PostConditions {
71 PredicatePtrMap specific_postcons_;
72 PredicateClassGuarantees generic_postcons_;
73 Guarantee default_postcon_;
74 PostConditions(
75 const PredicatePtrMap& specific_postcons = {},
76 const PredicateClassGuarantees& generic_postcons = {},
77 Guarantee default_postcon = Guarantee::Clear)
78 : specific_postcons_(specific_postcons),
79 generic_postcons_(generic_postcons),
80 default_postcon_(default_postcon) {}
81 };
82
83 /**
84 * @brief Default callback when applying a pass (does nothing)
85 */
86 void trivial_callback(const CompilationUnit&, const nlohmann::json&);
87
88 /* Passes are used to generate full sequences of rewrite rules for Circuits. It
89 internally stores pre and postcons which are composed together. Whenever a
90 CompilationUnit is passed through a Pass it has its cache of Predicates
91 updated accordingly. */
92
93 // Passes can be run in safe or unsafe mode (bool flag to dictate).
94 // Safe := runs Predicates at every StandardPass to make sure cache is being
95 // updated correctly Unsafe := When possible, updates the cache without running
96 // Predicates
97 // TODO: Super Unsafe AKA Cowabunga Mode := check nothing, allow everything
98
99 class BasePass {
100 public:
101
1/2
✓ Branch 4 taken 105 times.
✗ Branch 5 not taken.
105 BasePass() {}
102
103 /**
104 * @brief Apply the pass and invoke callbacks
105 * @param c_unit
106 * @param before_apply Called at the start of the apply procedure.
107 * The parameters are the CompilationUnit and a summary of the pass
108 * configuration.
109 * @param after_apply Called at the end of the apply procedure.
110 * The parameters are the CompilationUnit and a summary of the pass
111 * configuration.
112 * @param safe_mode
113 * @return True if pass modified the circuit, else False
114 */
115 virtual bool apply(
116 CompilationUnit& c_unit, SafetyMode safe_mode = SafetyMode::Default,
117 const PassCallback& before_apply = trivial_callback,
118 const PassCallback& after_apply = trivial_callback) const = 0;
119
120 friend PassPtr operator>>(const PassPtr& lhs, const PassPtr& rhs);
121
122 virtual std::string to_string() const = 0;
123
124 /**
125 * @brief Get the config object
126 *
127 * @return json containing the name, and params for the pass.
128 */
129 virtual nlohmann::json get_config() const = 0;
130 PassConditions get_conditions() const;
131 Guarantee get_guarantee(const std::type_index& ti) const;
132 static Guarantee get_guarantee(
133 const std::type_index& ti, const PassConditions& conditions);
134
135 virtual ~BasePass(){};
136
137 protected:
138 BasePass(const PredicatePtrMap& precons, const PostConditions& postcons)
139 : precons_(precons), postcons_(postcons) {}
140 PredicatePtrMap precons_;
141 PostConditions postcons_;
142
143 /**
144 * Check whether any preconditions of the compilation unit are unsatisfied.
145 *
146 * Returns the first unsatisfied precondition found.
147 *
148 * @param c_unit compilation unit
149 * @param[in] safe_mode safety mode
150 *
151 * @return unsatisfied precondition, if any
152 */
153 std::optional<PredicatePtr> unsatisfied_precondition(
154 const CompilationUnit& c_unit, SafetyMode safe_mode) const;
155
156 void update_cache(const CompilationUnit& c_unit, SafetyMode safe_mode) const;
157 static PassConditions match_passes(const PassPtr& lhs, const PassPtr& rhs);
158 static PassConditions match_passes(
159 const PassConditions& lhs, const PassConditions& rhs);
160 };
161
162 /* Basic Pass that all combinators can be used on */
163 class StandardPass : public BasePass {
164 public:
165 /**
166 * @brief Construct a new StandardPass object with info about the pass.
167 *
168 * @param precons
169 * @param trans
170 * @param postcons
171 * @param pass_config A nlohmann::json object containing the name, and params
172 * for the pass.
173 */
174 StandardPass(
175 const PredicatePtrMap& precons, const Transform& trans,
176 const PostConditions& postcons, const nlohmann::json& pass_config)
177 : BasePass(precons, postcons), trans_(trans), pass_config_(pass_config) {}
178
179 bool apply(
180 CompilationUnit& c_unit, SafetyMode = SafetyMode::Default,
181 const PassCallback& before_apply = trivial_callback,
182 const PassCallback& after_apply = trivial_callback) const override;
183 std::string to_string() const override;
184 nlohmann::json get_config() const override;
185
186 private:
187 Transform trans_;
188 nlohmann::json pass_config_ = "{\"name\": \"StandardPass\"}"_json;
189 };
190
191 /* Runs a sequence of Passes */
192 class SequencePass : public BasePass {
193 public:
194 59 SequencePass() {}
195 explicit SequencePass(const std::vector<PassPtr>& ptvec);
196 bool apply(
197 CompilationUnit& c_unit, SafetyMode safe_mode = SafetyMode::Default,
198 const PassCallback& before_apply = trivial_callback,
199 const PassCallback& after_apply = trivial_callback) const override {
200 before_apply(c_unit, this->get_config());
201 bool success = false;
202
0/2
✗ Decision 'true' not taken.
✗ Decision 'false' not taken.
for (const PassPtr& b : seq_)
203 success |= b->apply(c_unit, safe_mode, before_apply, after_apply);
204 after_apply(c_unit, this->get_config());
205 return success;
206 }
207 std::string to_string() const override;
208 nlohmann::json get_config() const override;
209 std::vector<PassPtr> get_sequence() const { return seq_; }
210
211 friend PassPtr operator>>(const PassPtr& lhs, const PassPtr& rhs);
212
213 private:
214 std::vector<PassPtr> seq_;
215 };
216
217 /* Repeats a Pass until it returns `false` */
218 class RepeatPass : public BasePass {
219 public:
220 explicit RepeatPass(const PassPtr& pass);
221 1 bool apply(
222 CompilationUnit& c_unit, SafetyMode safe_mode = SafetyMode::Default,
223 const PassCallback& before_apply = trivial_callback,
224 const PassCallback& after_apply = trivial_callback) const override {
225
1/2
✓ Branch 2 taken 1 times.
✗ Branch 3 not taken.
1 before_apply(c_unit, this->get_config());
226 1 bool success = false;
227
1/2
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
1/2
✗ Decision 'true' not taken.
✓ Decision 'false' taken 1 times.
1 while (pass_->apply(c_unit, safe_mode, before_apply, after_apply))
228 success = true;
229
1/2
✓ Branch 2 taken 1 times.
✗ Branch 3 not taken.
1 after_apply(c_unit, this->get_config());
230 1 return success;
231 }
232 std::string to_string() const override;
233 nlohmann::json get_config() const override;
234 PassPtr get_pass() const { return pass_; }
235
236 private:
237 PassPtr pass_;
238 };
239
240 class RepeatWithMetricPass : public BasePass {
241 public:
242 RepeatWithMetricPass(const PassPtr& pass, const Transform::Metric& metric);
243 bool apply(
244 CompilationUnit& c_unit, SafetyMode safe_mode = SafetyMode::Default,
245 const PassCallback& before_apply = trivial_callback,
246 const PassCallback& after_apply = trivial_callback) const override;
247 std::string to_string() const override;
248 nlohmann::json get_config() const override;
249 PassPtr get_pass() const { return pass_; }
250 Transform::Metric get_metric() const { return metric_; }
251
252 private:
253 PassPtr pass_;
254 Transform::Metric metric_;
255 };
256
257 class RepeatUntilSatisfiedPass : public BasePass {
258 public:
259 RepeatUntilSatisfiedPass(const PassPtr& pass, const PredicatePtr& to_satisfy);
260 RepeatUntilSatisfiedPass(
261 const PassPtr& pass, const std::function<bool(const Circuit&)>& func) {
262 PredicatePtr custom_pred = std::make_shared<UserDefinedPredicate>(func);
263 *this = RepeatUntilSatisfiedPass(pass, custom_pred);
264 }
265 /* Careful: If the predicate is never satisfied this will not terminate */
266 bool apply(
267 CompilationUnit& c_unit, SafetyMode safe_mode = SafetyMode::Default,
268 const PassCallback& before_apply = trivial_callback,
269 const PassCallback& after_apply = trivial_callback) const override;
270 std::string to_string() const override;
271 nlohmann::json get_config() const override;
272 PassPtr get_pass() const { return pass_; }
273 PredicatePtr get_predicate() const { return pred_; }
274
275 private:
276 PassPtr pass_;
277 PredicatePtr pred_;
278 };
279
280 // TODO: Repeat with a metric, repeat until a Predicate is satisfied...
281
282 } // namespace tket
283