Coverage for /home/runner/work/tket/tket/pytket/pytket/circuit/decompose_classical.py: 83%
213 statements
« prev ^ index » next coverage.py v7.8.2, created at 2025-06-02 12:44 +0000
« prev ^ index » next coverage.py v7.8.2, created at 2025-06-02 12:44 +0000
1# Copyright Quantinuum
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
"""Functions for decomposing Circuits containing classical expressions
into primitive logical operations."""
18from heapq import heappop, heappush
19from typing import Any, Generic, TypeVar
21from pytket._tket.circuit import (
22 Circuit,
23 ClBitVar,
24 ClExpr,
25 ClExprOp,
26 ClOp,
27 ClRegVar,
28 Conditional,
29 OpType,
30 WiredClExpr,
31)
32from pytket._tket.unit_id import (
33 _TEMP_BIT_NAME,
34 _TEMP_BIT_REG_BASE,
35 _TEMP_REG_SIZE,
36 Bit,
37 BitRegister,
38)
39from pytket.circuit.clexpr import check_register_alignments, has_reg_output
40from pytket.circuit.logic_exp import Constant, Variable
T = TypeVar("T")


class DecomposeClassicalError(Exception):
    """Error with decomposing classical operations."""


class VarHeap(Generic[T]):
    """A generic heap of free temporary variables.

    ``_heap`` holds the variables currently free for reuse, kept in ``heapq``
    order so that ``pop`` returns the smallest one. ``_heap_vars`` records every
    variable that has been pushed to (or, in subclasses, generated by) this
    heap, whether or not it is currently free. Subclasses implement
    :meth:`fresh_var`.
    """

    def __init__(self) -> None:
        self._heap: list[T] = []
        self._heap_vars: set[T] = set()

    def pop(self) -> T:
        """Pop the smallest free variable from the heap.

        :raises IndexError: if no variable is currently free
        """
        return heappop(self._heap)

    def push(self, var: T) -> None:
        """Mark ``var`` as free for reuse and record it as a heap variable.

        Pushing a variable that is already free is a no-op. Without this guard
        the same variable could sit in the heap twice and be handed out by two
        consecutive pops, aliasing two supposedly distinct temporaries.
        """
        self._heap_vars.add(var)
        # Linear scan is fine: the pool of free temporaries is small.
        if var not in self._heap:
            heappush(self._heap, var)

    def is_heap_var(self, var: T) -> bool:
        """Check if var was generated from heap."""
        return var in self._heap_vars

    def fresh_var(self) -> T:
        """Generate new variable.

        :raises NotImplementedError: always; subclasses must override this
        """
        raise NotImplementedError
class BitHeap(VarHeap[Bit]):
    """Pool of temporary Bits, all drawn from one named register."""

    def __init__(self, _reg_name: str = _TEMP_BIT_NAME):
        """Create an empty pool of temporary Bits.

        :param _reg_name: Name for register of Bits, defaults to _TEMP_BIT_NAME
        """
        self.reg_name = _reg_name
        super().__init__()

    @property
    def next_index(self) -> int:
        """One more than the largest index of any Bit recorded in this heap.

        Returns 0 when no Bit has been recorded yet.
        """
        indices = [bit.index[0] for bit in self._heap_vars]
        return max(indices) + 1 if indices else 0

    def fresh_var(self) -> Bit:
        """Hand out a free Bit if one is available, otherwise create one.

        A newly created Bit gets index :attr:`next_index` in register
        :attr:`reg_name` and is recorded as a heap variable.
        """
        if not self._heap:
            created = Bit(self.reg_name, self.next_index)
            self._heap_vars.add(created)
            return created
        return self.pop()
class RegHeap(VarHeap[BitRegister]):
    """Pool of temporary BitRegisters named ``<base>_<n>``."""

    def __init__(self, _reg_name_base: str = _TEMP_BIT_REG_BASE):
        """Create an empty pool of temporary BitRegisters.

        :param _reg_name_base: base string for register names, defaults to
            _TEMP_BIT_REG_BASE
        """
        self._reg_name_base = _reg_name_base
        super().__init__()

    @property
    def next_index(self) -> int:
        """One more than the largest numeric suffix of any recorded register.

        The suffix is the text after the final underscore of a register's
        name. Returns 0 when no register has been recorded yet.
        """
        suffixes = [int(reg.name.rsplit("_", 1)[-1]) for reg in self._heap_vars]
        return max(suffixes, default=-1) + 1

    def fresh_var(self, size: int = _TEMP_REG_SIZE) -> BitRegister:
        """Hand out a free BitRegister if one is available, otherwise create one.

        ``size`` only applies to a newly created register; a reused register
        keeps the size it already has.
        """
        if not self._heap:
            created = BitRegister(f"{self._reg_name_base}_{self.next_index}", size)
            self._heap_vars.add(created)
            return created
        return self.pop()
def temp_reg_in_args(args: list[Bit]) -> BitRegister | None:
    """Return the temporary register used by ``args``, if any.

    The register is identified by the first Bit in ``args`` whose register name
    starts with _TEMP_BIT_REG_BASE, and is always reported with size
    _TEMP_REG_SIZE. Returns None if no such Bit exists.
    """
    temp_names = [
        bit.reg_name for bit in args if bit.reg_name.startswith(_TEMP_BIT_REG_BASE)
    ]
    if not temp_names:
        return None
    return BitRegister(temp_names[0], _TEMP_REG_SIZE)


# Type variable constrained to the Bit and BitRegister classes themselves.
VarType = TypeVar("VarType", type[Bit], type[BitRegister])
def _int_to_bools(val: Constant, width: int) -> list[bool]:
    """Return the ``width`` least-significant bits of ``val`` as booleans.

    The encoding is little-endian: element ``i`` is bit ``i`` of ``val``. A
    ``width`` of 0 gives an empty list. For a negative ``val`` the bits are those
    of its two's-complement representation (Python's ``>>`` is arithmetic).
    """
    # Extract bits arithmetically rather than by slicing a binary string: the
    # slice ``[-width:]`` selects the *whole* string when width == 0, and a
    # negative value's "-" sign cannot be parsed as a digit.
    return [bool((val >> i) & 1) for i in range(width)]
146def _get_bit_width(x: int) -> int:
147 assert x >= 0
148 c = 0
149 while x:
150 x >>= 1
151 c += 1
152 return c
155class _ClExprDecomposer:
156 def __init__( # noqa: PLR0913
157 self,
158 circ: Circuit,
159 bit_posn: dict[int, int],
160 reg_posn: dict[int, list[int]],
161 args: list[Bit],
162 bit_heap: BitHeap,
163 reg_heap: RegHeap,
164 kwargs: dict[str, Any],
165 ):
166 self.circ: Circuit = circ
167 self.bit_posn: dict[int, int] = bit_posn
168 self.reg_posn: dict[int, list[int]] = reg_posn
169 self.args: list[Bit] = args
170 self.bit_heap: BitHeap = bit_heap
171 self.reg_heap: RegHeap = reg_heap
172 self.kwargs: dict[str, Any] = kwargs
173 # Construct maps from int (i.e. ClBitVar) to Bit, and from int (i.e. ClRegVar)
174 # to BitRegister:
175 self.bit_vars = {i: args[p] for i, p in bit_posn.items()}
176 self.reg_vars = {
177 i: BitRegister(args[p[0]].reg_name, len(p)) for i, p in reg_posn.items()
178 }
180 def add_var(self, var: Variable) -> None:
181 """Add a Bit or BitRegister to the circuit if not already present."""
182 if isinstance(var, Bit):
183 self.circ.add_bit(var, reject_dups=False)
184 else:
185 assert isinstance(var, BitRegister)
186 for bit in var.to_list():
187 self.circ.add_bit(bit, reject_dups=False)
189 def set_bits(self, var: Variable, val: int) -> None:
190 """Set the value of a Bit or BitRegister."""
191 assert val >= 0
192 if isinstance(var, Bit):
193 assert val >> 1 == 0
194 self.circ.add_c_setbits([bool(val)], [var], **self.kwargs)
195 else:
196 assert isinstance(var, BitRegister)
197 assert val >> var.size == 0
198 self.circ.add_c_setreg(val, var, **self.kwargs)
200 def decompose_expr(self, expr: ClExpr, out_var: Variable | None) -> Variable: # noqa: PLR0912, PLR0915
201 """Add the decomposed expression to the circuit and return the Bit or
202 BitRegister that contains the result.
204 :param expr: the expression to decompose
205 :param out_var: where to put the output (if None, create a new scratch location)
206 """
207 op: ClOp = expr.op
208 heap: VarHeap = self.reg_heap if has_reg_output(op) else self.bit_heap
210 # Eliminate (recursively) subsidiary expressions from the arguments, and convert
211 # all terms to Bit or BitRegister:
212 terms: list[Variable] = []
213 for arg in expr.args:
214 if isinstance(arg, int):
215 # Assign to a fresh variable
216 fresh_var = heap.fresh_var()
217 self.add_var(fresh_var)
218 self.set_bits(fresh_var, arg)
219 terms.append(fresh_var)
220 elif isinstance(arg, ClBitVar):
221 terms.append(self.bit_vars[arg.index])
222 elif isinstance(arg, ClRegVar):
223 terms.append(self.reg_vars[arg.index])
224 else:
225 assert isinstance(arg, ClExpr)
226 terms.append(self.decompose_expr(arg, None))
228 # Enable reuse of temporary terms:
229 for term in terms:
230 if heap.is_heap_var(term):
231 heap.push(term)
233 if out_var is None:
234 out_var = heap.fresh_var()
235 self.add_var(out_var)
236 match op:
237 case ClOp.BitAnd:
238 self.circ.add_c_and(*terms, out_var, **self.kwargs) # type: ignore
239 case ClOp.BitNot: 239 ↛ 240line 239 didn't jump to line 240 because the pattern on line 239 never matched
240 self.circ.add_c_not(*terms, out_var, **self.kwargs) # type: ignore
241 case ClOp.BitOne: 241 ↛ 242line 241 didn't jump to line 242 because the pattern on line 241 never matched
242 assert isinstance(out_var, Bit)
243 self.circ.add_c_setbits([True], [out_var], **self.kwargs)
244 case ClOp.BitOr:
245 self.circ.add_c_or(*terms, out_var, **self.kwargs) # type: ignore
246 case ClOp.BitXor:
247 self.circ.add_c_xor(*terms, out_var, **self.kwargs) # type: ignore
248 case ClOp.BitZero: 248 ↛ 249line 248 didn't jump to line 249 because the pattern on line 248 never matched
249 assert isinstance(out_var, Bit)
250 self.circ.add_c_setbits([False], [out_var], **self.kwargs)
251 case ClOp.RegAnd:
252 self.circ.add_c_and_to_registers(*terms, out_var, **self.kwargs) # type: ignore
253 case ClOp.RegNot: 253 ↛ 254line 253 didn't jump to line 254 because the pattern on line 253 never matched
254 self.circ.add_c_not_to_registers(*terms, out_var, **self.kwargs) # type: ignore
255 case ClOp.RegOne: 255 ↛ 256line 255 didn't jump to line 256 because the pattern on line 255 never matched
256 assert isinstance(out_var, BitRegister)
257 self.circ.add_c_setbits(
258 [True] * out_var.size, out_var.to_list(), **self.kwargs
259 )
260 case ClOp.RegOr:
261 self.circ.add_c_or_to_registers(*terms, out_var, **self.kwargs) # type: ignore
262 case ClOp.RegXor:
263 self.circ.add_c_xor_to_registers(*terms, out_var, **self.kwargs) # type: ignore
264 case ClOp.RegZero: 264 ↛ 265line 264 didn't jump to line 265 because the pattern on line 264 never matched
265 assert isinstance(out_var, BitRegister)
266 self.circ.add_c_setbits(
267 [False] * out_var.size, out_var.to_list(), **self.kwargs
268 )
269 case _:
270 raise DecomposeClassicalError(
271 f"{op} cannot be decomposed to TKET primitives."
272 )
273 return out_var
def _decompose_expressions(circ: Circuit) -> tuple[Circuit, bool]:  # noqa: PLR0912, PLR0915
    """Rewrite a circuit command-wise, decomposing ClExprOp.

    Every ClExprOp command (possibly wrapped in a Conditional) is replaced by an
    equivalent sequence of primitive classical operations. All other commands
    are copied across. Temporary bits of the input circuit are not copied up
    front; they are re-added only where a copied or generated operation uses
    them.

    :param circ: circuit to rewrite
    :return: the rewritten circuit, and whether any ClExprOp was decomposed
    :raises DecomposeClassicalError: if the circuit contains a ClExprOp that is
        not register-aligned, or an expression with no TKET primitive equivalent
    """
    if not check_register_alignments(circ):
        raise DecomposeClassicalError("Circuit contains non-register-aligned ClExprOp.")
    bit_heap = BitHeap()
    reg_heap = RegHeap()
    # add already used heap variables to heaps, so fresh names do not clash
    for b in circ.bits:
        if b.reg_name == _TEMP_BIT_NAME:
            bit_heap._heap_vars.add(b)  # noqa: SLF001
        elif b.reg_name.startswith(_TEMP_BIT_REG_BASE):
            reg_heap._heap_vars.add(BitRegister(b.reg_name, _TEMP_REG_SIZE))  # noqa: SLF001

    newcirc = Circuit(0, name=circ.name)

    for qb in circ.qubits:
        newcirc.add_qubit(qb)
    for cb in circ.bits:
        # lose all temporary bits, add back as required later
        if not (
            cb.reg_name.startswith(_TEMP_BIT_NAME)
            or cb.reg_name.startswith(_TEMP_BIT_REG_BASE)
        ):
            newcirc.add_bit(cb)

    # targets of predicates that need to be relabelled
    replace_targets: dict[Variable, Variable] = {}
    modified = False
    for command in circ:
        op = command.op
        optype = op.type
        args = command.args
        kwargs: dict[str, Any] = {}
        if optype == OpType.Conditional:
            assert isinstance(op, Conditional)
            bits = args[: op.width]
            # check if conditional on previously decomposed expression
            if len(bits) == 1 and bits[0] in replace_targets:
                assert isinstance(bits[0], Bit)
                # this op should encode comparison and value
                assert op.value in (0, 1)
                replace_bit = replace_targets[bits[0]]
                # The condition bit becomes available for reuse -- but only if it
                # really is a temporary. RangePredicate targets are recorded in
                # replace_targets too and may be ordinary circuit bits, which must
                # never be handed out (and overwritten) as scratch space.
                if bit_heap.is_heap_var(replace_bit):  # type: ignore
                    bit_heap.push(replace_bit)  # type: ignore

                # write new conditional op
                kwargs = {"condition_bits": [replace_bit], "condition_value": op.value}
            else:
                kwargs = {"condition_bits": bits, "condition_value": op.value}
            args = args[op.width :]
            op = op.op
            optype = op.type

        if optype == OpType.RangePredicate:
            target = args[-1]
            assert isinstance(target, Bit)
            newcirc.add_bit(target, reject_dups=False)
            temp_reg = temp_reg_in_args(args)  # type: ignore
            # ensure predicate is reading from correct output register
            if temp_reg in replace_targets:
                assert temp_reg is not None
                new_target = replace_targets[temp_reg]
                for i, a in enumerate(args):
                    if a.reg_name == temp_reg.name:
                        args[i] = Bit(new_target.name, a.index[0])  # type: ignore
            # operations conditional on this bit should remain so
            replace_targets[target] = target

        elif optype == OpType.ClExpr:
            assert isinstance(op, ClExprOp)
            wexpr: WiredClExpr = op.expr
            expr: ClExpr = wexpr.expr
            bit_posn = wexpr.bit_posn
            reg_posn = wexpr.reg_posn
            output_posn = wexpr.output_posn
            assert len(output_posn) > 0
            output0 = args[output_posn[0]]
            assert isinstance(output0, Bit)
            out_var: Variable = (
                BitRegister(output0.reg_name, len(output_posn))
                if has_reg_output(expr.op)
                else output0
            )
            decomposer = _ClExprDecomposer(
                newcirc,
                bit_posn,
                reg_posn,
                args,  # type: ignore
                bit_heap,
                reg_heap,
                kwargs,
            )
            comp_var = decomposer.decompose_expr(expr, out_var)
            # Defensive: decompose_expr writes into out_var whenever one is given.
            if comp_var != out_var:
                replace_targets[out_var] = comp_var
            modified = True
            continue

        if optype == OpType.Barrier:
            # add_gate doesn't work for metaops
            newcirc.add_barrier(args)
        else:
            for arg in args:
                if isinstance(arg, Bit):
                    # No-op for bits already present; avoids rebuilding
                    # ``newcirc.bits`` (a fresh list) for every argument.
                    newcirc.add_bit(arg, reject_dups=False)
            newcirc.add_gate(op, args, **kwargs)
    return newcirc, modified