Coverage for /home/runner/work/tket/tket/pytket/pytket/circuit/decompose_classical.py: 83%
213 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-14 11:30 +0000
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-14 11:30 +0000
1# Copyright Quantinuum
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
"""Functions for decomposing Circuits containing classical expressions
into primitive logical operations."""
18from heapq import heappop, heappush
19from typing import Any, Generic, TypeVar
21from pytket._tket.circuit import (
22 Circuit,
23 ClBitVar,
24 ClExpr,
25 ClExprOp,
26 ClOp,
27 ClRegVar,
28 Conditional,
29 OpType,
30 WiredClExpr,
31)
32from pytket._tket.unit_id import (
33 _TEMP_BIT_NAME,
34 _TEMP_BIT_REG_BASE,
35 _TEMP_REG_SIZE,
36 Bit,
37 BitRegister,
38)
39from pytket.circuit.clexpr import check_register_alignments, has_reg_output
40from pytket.circuit.logic_exp import Constant, Variable
T = TypeVar("T")


class DecomposeClassicalError(Exception):
    """Error with decomposing classical operations."""


class VarHeap(Generic[T]):
    """A min-heap of free temporary variables.

    Every variable generated by, or returned to, the heap is recorded in
    ``_heap_vars`` so callers can later ask whether a variable is a temporary
    (:meth:`is_heap_var`). Only variables that are currently free are stored
    in ``_heap``; :meth:`pop` removes and returns the smallest of them.
    """

    def __init__(self) -> None:
        # Currently free variables, kept in heap order.
        self._heap: list[T] = []
        # All variables ever generated by or pushed to this heap.
        self._heap_vars: set[T] = set()

    def pop(self) -> T:
        """Remove and return the smallest free variable.

        :raises IndexError: if there are no free variables.
        """
        return heappop(self._heap)

    def push(self, var: T) -> None:
        """Return *var* to the pool of free variables.

        Pushing a variable that is already free has no effect. Without this
        check the heap could hold duplicates, and two later pops could hand
        out the same temporary to two different users.
        """
        self._heap_vars.add(var)
        if var not in self._heap:
            heappush(self._heap, var)

    def is_heap_var(self, var: T) -> bool:
        """Check if var was generated from (or returned to) this heap."""
        return var in self._heap_vars

    def fresh_var(self) -> T:
        """Generate new variable. Must be implemented by subclasses."""
        raise NotImplementedError
class BitHeap(VarHeap[Bit]):
    """Pool of temporary Bits, all belonging to one named register."""

    def __init__(self, _reg_name: str = _TEMP_BIT_NAME):
        """Create an empty BitHeap.

        :param _reg_name: Name for register of Bits, defaults to _TEMP_BIT_NAME
        :type _reg_name: str, optional
        """
        self.reg_name = _reg_name
        super().__init__()

    @property
    def next_index(self) -> int:
        """One past the largest index of any bit known to this heap (0 if none)."""
        used_indices = [bit.index[0] for bit in self._heap_vars]
        return 1 + max(used_indices) if used_indices else 0

    def fresh_var(self) -> Bit:
        """Hand out a free Bit, or create (and record) a new one if none is free."""
        if not self._heap:
            created = Bit(self.reg_name, self.next_index)
            self._heap_vars.add(created)
            return created
        return self.pop()
class RegHeap(VarHeap[BitRegister]):
    """Pool of temporary BitRegisters named ``<base>_<n>``."""

    def __init__(self, _reg_name_base: str = _TEMP_BIT_REG_BASE):
        """Create an empty RegHeap.

        :param _reg_name_base: base string for register names, defaults to
            _TEMP_BIT_REG_BASE
        :type _reg_name_base: str, optional
        """
        self._reg_name_base = _reg_name_base
        super().__init__()

    @property
    def next_index(self) -> int:
        """Next register suffix not used by any register known to this heap.

        The suffix of a register is the integer following the last underscore
        in its name; the result is one past the largest such suffix (0 if the
        heap knows no registers).
        """
        suffixes = [int(reg.name.rsplit("_", 1)[-1]) for reg in self._heap_vars]
        return max(suffixes) + 1 if suffixes else 0

    def fresh_var(self, size: int = _TEMP_REG_SIZE) -> BitRegister:
        """Hand out a free BitRegister, or create (and record) a new one.

        *size* only applies to a newly created register; a reused free
        register keeps whatever size it already has.
        """
        if not self._heap:
            created = BitRegister(f"{self._reg_name_base}_{self.next_index}", size)
            self._heap_vars.add(created)
            return created
        return self.pop()
def temp_reg_in_args(args: list[Bit]) -> BitRegister | None:
    """Find the temporary register referenced by *args*, if any.

    Returns a BitRegister (of size _TEMP_REG_SIZE) named after the register of
    the first bit in *args* whose register name starts with
    _TEMP_BIT_REG_BASE, or None when no such bit exists.
    """
    for bit in args:
        if bit.reg_name.startswith(_TEMP_BIT_REG_BASE):
            return BitRegister(bit.reg_name, _TEMP_REG_SIZE)
    return None


VarType = TypeVar("VarType", type[Bit], type[BitRegister])
def int_to_bools(val: "Constant", width: int) -> list[bool]:
    """Encode a non-negative integer as a little-endian list of bools.

    Element ``i`` of the result is bit ``i`` of *val*. Bits of *val* at
    positions ``>= width`` are discarded, so the result always has exactly
    *width* elements (and is empty when *width* is 0).

    :param val: non-negative integer to encode
    :param width: number of bits to produce
    :return: list of *width* bools, least significant bit first
    :raises ValueError: if *val* or *width* is negative
    """
    if val < 0:
        raise ValueError(f"Cannot encode negative value {val} as bits.")
    if width < 0:
        raise ValueError(f"Bit width must be non-negative, got {width}.")
    return [bool((val >> i) & 1) for i in range(width)]
def get_bit_width(x: int) -> int:
    """Return the number of bits needed to represent the non-negative int *x*.

    The result is 0 for ``x == 0``.

    :param x: non-negative integer
    :return: bit length of *x*
    :raises ValueError: if *x* is negative
    """
    # Explicit check rather than ``assert``: with asserts stripped (python -O)
    # a negative input would otherwise go unreported.
    if x < 0:
        raise ValueError(f"Bit width is undefined for negative value {x}.")
    return x.bit_length()
157class ClExprDecomposer:
158 def __init__(
159 self,
160 circ: Circuit,
161 bit_posn: dict[int, int],
162 reg_posn: dict[int, list[int]],
163 args: list[Bit],
164 bit_heap: BitHeap,
165 reg_heap: RegHeap,
166 kwargs: dict[str, Any],
167 ):
168 self.circ: Circuit = circ
169 self.bit_posn: dict[int, int] = bit_posn
170 self.reg_posn: dict[int, list[int]] = reg_posn
171 self.args: list[Bit] = args
172 self.bit_heap: BitHeap = bit_heap
173 self.reg_heap: RegHeap = reg_heap
174 self.kwargs: dict[str, Any] = kwargs
175 # Construct maps from int (i.e. ClBitVar) to Bit, and from int (i.e. ClRegVar)
176 # to BitRegister:
177 self.bit_vars = {i: args[p] for i, p in bit_posn.items()}
178 self.reg_vars = {
179 i: BitRegister(args[p[0]].reg_name, len(p)) for i, p in reg_posn.items()
180 }
182 def add_var(self, var: Variable) -> None:
183 """Add a Bit or BitRegister to the circuit if not already present."""
184 if isinstance(var, Bit):
185 self.circ.add_bit(var, reject_dups=False)
186 else:
187 assert isinstance(var, BitRegister)
188 for bit in var.to_list():
189 self.circ.add_bit(bit, reject_dups=False)
191 def set_bits(self, var: Variable, val: int) -> None:
192 """Set the value of a Bit or BitRegister."""
193 assert val >= 0
194 if isinstance(var, Bit):
195 assert val >> 1 == 0
196 self.circ.add_c_setbits([bool(val)], [var], **self.kwargs)
197 else:
198 assert isinstance(var, BitRegister)
199 assert val >> var.size == 0
200 self.circ.add_c_setreg(val, var, **self.kwargs)
202 def decompose_expr(self, expr: ClExpr, out_var: Variable | None) -> Variable:
203 """Add the decomposed expression to the circuit and return the Bit or
204 BitRegister that contains the result.
206 :param expr: the expression to decompose
207 :param out_var: where to put the output (if None, create a new scratch location)
208 """
209 op: ClOp = expr.op
210 heap: VarHeap = self.reg_heap if has_reg_output(op) else self.bit_heap
212 # Eliminate (recursively) subsidiary expressions from the arguments, and convert
213 # all terms to Bit or BitRegister:
214 terms: list[Variable] = []
215 for arg in expr.args:
216 if isinstance(arg, int):
217 # Assign to a fresh variable
218 fresh_var = heap.fresh_var()
219 self.add_var(fresh_var)
220 self.set_bits(fresh_var, arg)
221 terms.append(fresh_var)
222 elif isinstance(arg, ClBitVar):
223 terms.append(self.bit_vars[arg.index])
224 elif isinstance(arg, ClRegVar):
225 terms.append(self.reg_vars[arg.index])
226 else:
227 assert isinstance(arg, ClExpr)
228 terms.append(self.decompose_expr(arg, None))
230 # Enable reuse of temporary terms:
231 for term in terms:
232 if heap.is_heap_var(term):
233 heap.push(term)
235 if out_var is None:
236 out_var = heap.fresh_var()
237 self.add_var(out_var)
238 match op:
239 case ClOp.BitAnd:
240 self.circ.add_c_and(*terms, out_var, **self.kwargs) # type: ignore
241 case ClOp.BitNot: 241 ↛ 242line 241 didn't jump to line 242 because the pattern on line 241 never matched
242 self.circ.add_c_not(*terms, out_var, **self.kwargs) # type: ignore
243 case ClOp.BitOne: 243 ↛ 244line 243 didn't jump to line 244 because the pattern on line 243 never matched
244 assert isinstance(out_var, Bit)
245 self.circ.add_c_setbits([True], [out_var], **self.kwargs)
246 case ClOp.BitOr:
247 self.circ.add_c_or(*terms, out_var, **self.kwargs) # type: ignore
248 case ClOp.BitXor:
249 self.circ.add_c_xor(*terms, out_var, **self.kwargs) # type: ignore
250 case ClOp.BitZero: 250 ↛ 251line 250 didn't jump to line 251 because the pattern on line 250 never matched
251 assert isinstance(out_var, Bit)
252 self.circ.add_c_setbits([False], [out_var], **self.kwargs)
253 case ClOp.RegAnd:
254 self.circ.add_c_and_to_registers(*terms, out_var, **self.kwargs) # type: ignore
255 case ClOp.RegNot: 255 ↛ 256line 255 didn't jump to line 256 because the pattern on line 255 never matched
256 self.circ.add_c_not_to_registers(*terms, out_var, **self.kwargs) # type: ignore
257 case ClOp.RegOne: 257 ↛ 258line 257 didn't jump to line 258 because the pattern on line 257 never matched
258 assert isinstance(out_var, BitRegister)
259 self.circ.add_c_setbits(
260 [True] * out_var.size, out_var.to_list(), **self.kwargs
261 )
262 case ClOp.RegOr:
263 self.circ.add_c_or_to_registers(*terms, out_var, **self.kwargs) # type: ignore
264 case ClOp.RegXor:
265 self.circ.add_c_xor_to_registers(*terms, out_var, **self.kwargs) # type: ignore
266 case ClOp.RegZero: 266 ↛ 267line 266 didn't jump to line 267 because the pattern on line 266 never matched
267 assert isinstance(out_var, BitRegister)
268 self.circ.add_c_setbits(
269 [False] * out_var.size, out_var.to_list(), **self.kwargs
270 )
271 case _:
272 raise DecomposeClassicalError(
273 f"{op} cannot be decomposed to TKET primitives."
274 )
275 return out_var
def _decompose_expressions(circ: Circuit) -> tuple[Circuit, bool]:
    """Rewrite a circuit command-wise, decomposing ClExprOp.

    Every ClExprOp (possibly wrapped in a Conditional) is replaced by primitive
    classical operations built by ClExprDecomposer. Other commands are copied
    across, with predicate targets relabelled where needed. Temporary bits and
    registers already present in *circ* are left out of the new circuit's
    initial bit list and re-added only when a command uses them.

    :param circ: circuit to rewrite
    :return: the rewritten circuit, and whether any ClExprOp was decomposed
    :raises DecomposeClassicalError: if the circuit contains a ClExprOp whose
        register arguments are not aligned
    """
    if not check_register_alignments(circ):
        raise DecomposeClassicalError("Circuit contains non-register-aligned ClExprOp.")
    bit_heap = BitHeap()
    reg_heap = RegHeap()
    # Add already used heap variables to the heaps, so that freshly generated
    # temporaries get indices distinct from existing ones.
    for b in circ.bits:
        if b.reg_name == _TEMP_BIT_NAME:
            bit_heap._heap_vars.add(b)
        elif b.reg_name.startswith(_TEMP_BIT_REG_BASE):
            reg_heap._heap_vars.add(BitRegister(b.reg_name, _TEMP_REG_SIZE))

    newcirc = Circuit(0, name=circ.name)

    for qb in circ.qubits:
        newcirc.add_qubit(qb)
    for cb in circ.bits:
        # lose all temporary bits, add back as required later
        if not (
            cb.reg_name.startswith(_TEMP_BIT_NAME)
            or cb.reg_name.startswith(_TEMP_BIT_REG_BASE)
        ):
            newcirc.add_bit(cb)

    # targets of predicates that need to be relabelled
    replace_targets: dict[Variable, Variable] = {}
    modified = False
    for command in circ:
        op = command.op
        optype = op.type
        args = command.args
        kwargs: dict[str, Any] = {}
        if optype == OpType.Conditional:
            assert isinstance(op, Conditional)
            bits = args[: op.width]
            # check if conditional on previously decomposed expression
            if len(bits) == 1 and bits[0] in replace_targets:
                assert isinstance(bits[0], Bit)
                # this op should encode comparison and value
                assert op.value in (0, 1)
                replace_bit = replace_targets[bits[0]]
                # A temporary condition bit is available for reuse. Only scratch
                # bits may be released: a predicate target can be an ordinary
                # circuit bit, which must never be handed out (and overwritten)
                # as a temporary.
                if bit_heap.is_heap_var(replace_bit):  # type: ignore
                    bit_heap.push(replace_bit)  # type: ignore

                # write new conditional op
                kwargs = {"condition_bits": [replace_bit], "condition_value": op.value}
            else:
                kwargs = {"condition_bits": bits, "condition_value": op.value}
            args = args[op.width :]
            op = op.op
            optype = op.type

        if optype == OpType.RangePredicate:
            target = args[-1]
            assert isinstance(target, Bit)
            newcirc.add_bit(target, reject_dups=False)
            temp_reg = temp_reg_in_args(args)  # type: ignore
            # ensure predicate is reading from correct output register
            if temp_reg is not None and temp_reg in replace_targets:
                new_target = replace_targets[temp_reg]
                for i, a in enumerate(args):
                    if a.reg_name == temp_reg.name:
                        args[i] = Bit(new_target.name, a.index[0])  # type: ignore
            # operations conditional on this bit should remain so
            replace_targets[target] = target

        elif optype == OpType.ClExpr:
            assert isinstance(op, ClExprOp)
            wexpr: WiredClExpr = op.expr
            expr: ClExpr = wexpr.expr
            bit_posn = wexpr.bit_posn
            reg_posn = wexpr.reg_posn
            output_posn = wexpr.output_posn
            assert len(output_posn) > 0
            output0 = args[output_posn[0]]
            assert isinstance(output0, Bit)
            out_var: Variable = (
                BitRegister(output0.reg_name, len(output_posn))
                if has_reg_output(expr.op)
                else output0
            )
            decomposer = ClExprDecomposer(
                newcirc, bit_posn, reg_posn, args, bit_heap, reg_heap, kwargs  # type: ignore
            )
            comp_var = decomposer.decompose_expr(expr, out_var)
            if comp_var != out_var:
                replace_targets[out_var] = comp_var
            modified = True
            continue

        if optype == OpType.Barrier:
            # add_gate doesn't work for metaops
            newcirc.add_barrier(args)
        else:
            for arg in args:
                # Temporary bits were left out of newcirc above; re-add any that
                # this command uses (a no-op for bits already present).
                if isinstance(arg, Bit):
                    newcirc.add_bit(arg, reject_dups=False)
            newcirc.add_gate(op, args, **kwargs)
    return newcirc, modified