Coverage for /home/runner/work/tket/tket/pytket/pytket/circuit/decompose_classical.py: 83%

213 statements  

« prev     ^ index     » next       coverage.py v7.8.2, created at 2025-06-02 12:44 +0000

1# Copyright Quantinuum 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15"""Functions for decomposing Circuits containing classical expressions 

16into primitive logical operations."""

17 

18from heapq import heappop, heappush 

19from typing import Any, Generic, TypeVar 

20 

21from pytket._tket.circuit import ( 

22 Circuit, 

23 ClBitVar, 

24 ClExpr, 

25 ClExprOp, 

26 ClOp, 

27 ClRegVar, 

28 Conditional, 

29 OpType, 

30 WiredClExpr, 

31) 

32from pytket._tket.unit_id import ( 

33 _TEMP_BIT_NAME, 

34 _TEMP_BIT_REG_BASE, 

35 _TEMP_REG_SIZE, 

36 Bit, 

37 BitRegister, 

38) 

39from pytket.circuit.clexpr import check_register_alignments, has_reg_output 

40from pytket.circuit.logic_exp import Constant, Variable 

41 

# Type of the variables stored in a VarHeap (Bit or BitRegister in the subclasses).
T = TypeVar("T")

43 

44 

class DecomposeClassicalError(Exception):
    """Raised when classical operations cannot be decomposed."""

47 

48 

class VarHeap(Generic[T]):
    """Min-heap of reusable variables that also records every variable pushed."""

    def __init__(self) -> None:
        # Variables currently available for reuse, kept in heap order.
        self._heap: list[T] = []
        # Every variable that has been pushed; subclasses also add the
        # variables they create.
        self._heap_vars: set[T] = set()

    def pop(self) -> T:
        """Remove and return the smallest variable on the heap."""
        smallest = heappop(self._heap)
        return smallest

    def push(self, var: T) -> None:
        """Make ``var`` available for reuse and record it as a heap variable."""
        heappush(self._heap, var)
        self._heap_vars.add(var)

    def is_heap_var(self, var: T) -> bool:
        """Return whether ``var`` is one of the variables recorded by this heap."""
        known = self._heap_vars
        return var in known

    def fresh_var(self) -> T:
        """Produce a variable for use; this base implementation always raises.

        :raises NotImplementedError: always
        """
        raise NotImplementedError

72 

73 

class BitHeap(VarHeap[Bit]):
    """Heap of temporary Bits that share one register name."""

    def __init__(self, _reg_name: str = _TEMP_BIT_NAME):
        """Create an empty BitHeap.

        :param _reg_name: register name given to the Bits this heap creates,
            defaults to _TEMP_BIT_NAME
        """
        super().__init__()
        self.reg_name = _reg_name

    @property
    def next_index(self) -> int:
        """One more than the largest index among recorded bits, or 0 if none."""
        used = [bit.index[0] for bit in self._heap_vars]
        return max(used) + 1 if used else 0

    def fresh_var(self) -> Bit:
        """Hand out a Bit: reuse one from the heap if any, otherwise create one."""
        if not self._heap:
            created = Bit(self.reg_name, self.next_index)
            # Record the new bit so that its index is never handed out again.
            self._heap_vars.add(created)
            return created
        return self.pop()

98 

99 

class RegHeap(VarHeap[BitRegister]):
    """Heap of temporary BitRegisters named ``<base>_<index>``."""

    def __init__(self, _reg_name_base: str = _TEMP_BIT_REG_BASE):
        """Create an empty RegHeap.

        :param _reg_name_base: prefix of the names of the registers this heap
            creates, defaults to _TEMP_BIT_REG_BASE
        """
        super().__init__()
        self._reg_name_base = _reg_name_base

    @property
    def next_index(self) -> int:
        """One more than the largest numeric name suffix among recorded
        registers, or 0 if there are none."""
        # The index is whatever follows the last underscore of the name.
        suffixes = [int(reg.name.rsplit("_", 1)[-1]) for reg in self._heap_vars]
        return max(suffixes) + 1 if suffixes else 0

    def fresh_var(self, size: int = _TEMP_REG_SIZE) -> BitRegister:
        """Hand out a BitRegister: reuse one from the heap if any, otherwise
        create one.

        :param size: size of the register if a new one has to be created; it
            is not consulted when a register is reused from the heap
        """
        if not self._heap:
            created = BitRegister(f"{self._reg_name_base}_{self.next_index}", size)
            # Record the new register so that its index is never reused.
            self._heap_vars.add(created)
            return created
        return self.pop()

128 

129 

def temp_reg_in_args(args: list[Bit]) -> BitRegister | None:
    """Return the temporary register that ``args`` draws from, if any.

    The first bit whose register name starts with _TEMP_BIT_REG_BASE decides
    the result; the returned register has that name and size _TEMP_REG_SIZE.
    ``None`` is returned when no bit matches.
    """
    for bit in args:
        if bit.reg_name.startswith(_TEMP_BIT_REG_BASE):
            return BitRegister(bit.reg_name, _TEMP_REG_SIZE)
    return None

136 

137 

# Type variable constrained to the Bit and BitRegister classes themselves
# (``type[...]``), not to their instances.
VarType = TypeVar("VarType", type[Bit], type[BitRegister])

139 

140 

def _int_to_bools(val: Constant, width: int) -> list[bool]:
    """Return the ``width`` least-significant bits of ``val``, little-endian.

    Element ``i`` of the result is bit ``i`` of ``val``. Bits of ``val`` at
    position ``width`` and above are dropped; if ``val`` needs fewer than
    ``width`` bits the result is padded with ``False``.

    :param val: non-negative integer to encode
    :param width: number of bits to produce; 0 gives an empty list
    :raises ValueError: if ``val`` or ``width`` is negative
    """
    if width < 0:
        raise ValueError(f"width must be non-negative, got {width}.")
    if val < 0:
        raise ValueError(f"Cannot encode negative value {val} as bits.")
    # Shifting avoids the string round trip, whose ``[-width:]`` slice returned
    # the whole binary string when ``width`` was 0.
    return [bool((val >> i) & 1) for i in range(width)]

144 

145 

def _get_bit_width(x: int) -> int:
    """Return the number of bits needed to represent ``x`` in binary.

    The width of 0 is 0.

    :param x: non-negative integer
    :raises ValueError: if ``x`` is negative
    """
    # Validate explicitly rather than with ``assert``: with assertions
    # disabled, a negative value would otherwise be accepted silently.
    if x < 0:
        raise ValueError(f"Bit width requires a non-negative integer, got {x}.")
    return int(x).bit_length()

153 

154 

155class _ClExprDecomposer: 

156 def __init__( # noqa: PLR0913 

157 self, 

158 circ: Circuit, 

159 bit_posn: dict[int, int], 

160 reg_posn: dict[int, list[int]], 

161 args: list[Bit], 

162 bit_heap: BitHeap, 

163 reg_heap: RegHeap, 

164 kwargs: dict[str, Any], 

165 ): 

166 self.circ: Circuit = circ 

167 self.bit_posn: dict[int, int] = bit_posn 

168 self.reg_posn: dict[int, list[int]] = reg_posn 

169 self.args: list[Bit] = args 

170 self.bit_heap: BitHeap = bit_heap 

171 self.reg_heap: RegHeap = reg_heap 

172 self.kwargs: dict[str, Any] = kwargs 

173 # Construct maps from int (i.e. ClBitVar) to Bit, and from int (i.e. ClRegVar) 

174 # to BitRegister: 

175 self.bit_vars = {i: args[p] for i, p in bit_posn.items()} 

176 self.reg_vars = { 

177 i: BitRegister(args[p[0]].reg_name, len(p)) for i, p in reg_posn.items() 

178 } 

179 

180 def add_var(self, var: Variable) -> None: 

181 """Add a Bit or BitRegister to the circuit if not already present.""" 

182 if isinstance(var, Bit): 

183 self.circ.add_bit(var, reject_dups=False) 

184 else: 

185 assert isinstance(var, BitRegister) 

186 for bit in var.to_list(): 

187 self.circ.add_bit(bit, reject_dups=False) 

188 

189 def set_bits(self, var: Variable, val: int) -> None: 

190 """Set the value of a Bit or BitRegister.""" 

191 assert val >= 0 

192 if isinstance(var, Bit): 

193 assert val >> 1 == 0 

194 self.circ.add_c_setbits([bool(val)], [var], **self.kwargs) 

195 else: 

196 assert isinstance(var, BitRegister) 

197 assert val >> var.size == 0 

198 self.circ.add_c_setreg(val, var, **self.kwargs) 

199 

200 def decompose_expr(self, expr: ClExpr, out_var: Variable | None) -> Variable: # noqa: PLR0912, PLR0915 

201 """Add the decomposed expression to the circuit and return the Bit or 

202 BitRegister that contains the result. 

203 

204 :param expr: the expression to decompose 

205 :param out_var: where to put the output (if None, create a new scratch location) 

206 """ 

207 op: ClOp = expr.op 

208 heap: VarHeap = self.reg_heap if has_reg_output(op) else self.bit_heap 

209 

210 # Eliminate (recursively) subsidiary expressions from the arguments, and convert 

211 # all terms to Bit or BitRegister: 

212 terms: list[Variable] = [] 

213 for arg in expr.args: 

214 if isinstance(arg, int): 

215 # Assign to a fresh variable 

216 fresh_var = heap.fresh_var() 

217 self.add_var(fresh_var) 

218 self.set_bits(fresh_var, arg) 

219 terms.append(fresh_var) 

220 elif isinstance(arg, ClBitVar): 

221 terms.append(self.bit_vars[arg.index]) 

222 elif isinstance(arg, ClRegVar): 

223 terms.append(self.reg_vars[arg.index]) 

224 else: 

225 assert isinstance(arg, ClExpr) 

226 terms.append(self.decompose_expr(arg, None)) 

227 

228 # Enable reuse of temporary terms: 

229 for term in terms: 

230 if heap.is_heap_var(term): 

231 heap.push(term) 

232 

233 if out_var is None: 

234 out_var = heap.fresh_var() 

235 self.add_var(out_var) 

236 match op: 

237 case ClOp.BitAnd: 

238 self.circ.add_c_and(*terms, out_var, **self.kwargs) # type: ignore 

239 case ClOp.BitNot: 239 ↛ 240line 239 didn't jump to line 240 because the pattern on line 239 never matched

240 self.circ.add_c_not(*terms, out_var, **self.kwargs) # type: ignore 

241 case ClOp.BitOne: 241 ↛ 242line 241 didn't jump to line 242 because the pattern on line 241 never matched

242 assert isinstance(out_var, Bit) 

243 self.circ.add_c_setbits([True], [out_var], **self.kwargs) 

244 case ClOp.BitOr: 

245 self.circ.add_c_or(*terms, out_var, **self.kwargs) # type: ignore 

246 case ClOp.BitXor: 

247 self.circ.add_c_xor(*terms, out_var, **self.kwargs) # type: ignore 

248 case ClOp.BitZero: 248 ↛ 249line 248 didn't jump to line 249 because the pattern on line 248 never matched

249 assert isinstance(out_var, Bit) 

250 self.circ.add_c_setbits([False], [out_var], **self.kwargs) 

251 case ClOp.RegAnd: 

252 self.circ.add_c_and_to_registers(*terms, out_var, **self.kwargs) # type: ignore 

253 case ClOp.RegNot: 253 ↛ 254line 253 didn't jump to line 254 because the pattern on line 253 never matched

254 self.circ.add_c_not_to_registers(*terms, out_var, **self.kwargs) # type: ignore 

255 case ClOp.RegOne: 255 ↛ 256line 255 didn't jump to line 256 because the pattern on line 255 never matched

256 assert isinstance(out_var, BitRegister) 

257 self.circ.add_c_setbits( 

258 [True] * out_var.size, out_var.to_list(), **self.kwargs 

259 ) 

260 case ClOp.RegOr: 

261 self.circ.add_c_or_to_registers(*terms, out_var, **self.kwargs) # type: ignore 

262 case ClOp.RegXor: 

263 self.circ.add_c_xor_to_registers(*terms, out_var, **self.kwargs) # type: ignore 

264 case ClOp.RegZero: 264 ↛ 265line 264 didn't jump to line 265 because the pattern on line 264 never matched

265 assert isinstance(out_var, BitRegister) 

266 self.circ.add_c_setbits( 

267 [False] * out_var.size, out_var.to_list(), **self.kwargs 

268 ) 

269 case _: 

270 raise DecomposeClassicalError( 

271 f"{op} cannot be decomposed to TKET primitives." 

272 ) 

273 return out_var 

274 

275 

def _decompose_expressions(circ: Circuit) -> tuple[Circuit, bool]:  # noqa: PLR0912, PLR0915
    """Rewrite a circuit command-wise, decomposing ClExprOp.

    A new circuit is built with the same name and qubits and with every bit
    of ``circ`` that is not in a temporary register. Commands are then copied
    across one at a time, with each ClExprOp replaced by the primitive
    operations produced by ``_ClExprDecomposer``.

    :param circ: circuit to rewrite
    :return: the new circuit, and a flag that is True if at least one
        ClExprOp was decomposed
    :raises DecomposeClassicalError: if ``check_register_alignments`` rejects
        ``circ`` or an expression uses an operation that cannot be decomposed
    """
    if not check_register_alignments(circ):
        raise DecomposeClassicalError("Circuit contains non-register-aligned ClExprOp.")
    bit_heap = BitHeap()
    reg_heap = RegHeap()
    # Record temporary bits/registers already present in the circuit, so that
    # the heaps' next_index skips the names they use.
    for b in circ.bits:
        if b.reg_name == _TEMP_BIT_NAME:
            bit_heap._heap_vars.add(b)  # noqa: SLF001
        elif b.reg_name.startswith(_TEMP_BIT_REG_BASE):
            reg_heap._heap_vars.add(BitRegister(b.reg_name, _TEMP_REG_SIZE))  # noqa: SLF001

    newcirc = Circuit(0, name=circ.name)

    for qb in circ.qubits:
        newcirc.add_qubit(qb)
    for cb in circ.bits:
        # lose all temporary bits, add back as required later
        if not (
            cb.reg_name.startswith(_TEMP_BIT_NAME)
            or cb.reg_name.startswith(_TEMP_BIT_REG_BASE)
        ):
            newcirc.add_bit(cb)

    # targets of predicates that need to be relabelled
    replace_targets: dict[Variable, Variable] = {}
    modified = False
    for command in circ:
        op = command.op
        optype = op.type
        args = command.args
        kwargs = {}
        if optype == OpType.Conditional:
            assert isinstance(op, Conditional)
            # The first ``op.width`` arguments are the condition bits.
            bits = args[: op.width]
            # check if conditional on previously decomposed expression
            if len(bits) == 1 and bits[0] in replace_targets:
                assert isinstance(bits[0], Bit)
                # this op should encode comparison and value
                assert op.value in (0, 1)
                replace_bit = replace_targets[bits[0]]
                # temporary condition bit is available for reuse
                # NOTE(review): range-predicate targets are mapped to
                # themselves below, so replace_bit is not necessarily a
                # scratch bit -- confirm that pushing it is intended.
                bit_heap.push(replace_bit)  # type: ignore

                # write new conditional op
                kwargs = {"condition_bits": [replace_bit], "condition_value": op.value}
            else:
                kwargs = {"condition_bits": bits, "condition_value": op.value}
            # Continue with the wrapped operation and its own arguments; the
            # condition is re-applied through ``kwargs``.
            args = args[op.width :]
            op = op.op
            optype = op.type

        if optype == OpType.RangePredicate:
            # The predicate's result bit is its last argument.
            target = args[-1]
            assert isinstance(target, Bit)
            newcirc.add_bit(target, reject_dups=False)
            temp_reg = temp_reg_in_args(args)  # type: ignore
            # ensure predicate is reading from correct output register
            if temp_reg in replace_targets:
                assert temp_reg is not None
                new_target = replace_targets[temp_reg]
                for i, a in enumerate(args):
                    if a.reg_name == temp_reg.name:
                        args[i] = Bit(new_target.name, a.index[0])  # type: ignore
            # operations conditional on this bit should remain so
            replace_targets[target] = target

        elif optype == OpType.ClExpr:
            assert isinstance(op, ClExprOp)
            wexpr: WiredClExpr = op.expr
            expr: ClExpr = wexpr.expr
            bit_posn = wexpr.bit_posn
            reg_posn = wexpr.reg_posn
            output_posn = wexpr.output_posn
            assert len(output_posn) > 0
            output0 = args[output_posn[0]]
            assert isinstance(output0, Bit)
            # Register-valued expressions write to the register named after
            # the first output bit; bit-valued ones write to that bit itself.
            out_var: Variable = (
                BitRegister(output0.reg_name, len(output_posn))
                if has_reg_output(expr.op)
                else output0
            )
            decomposer = _ClExprDecomposer(
                newcirc,
                bit_posn,
                reg_posn,
                args,  # type: ignore
                bit_heap,
                reg_heap,
                kwargs,
            )
            comp_var = decomposer.decompose_expr(expr, out_var)
            if comp_var != out_var:
                replace_targets[out_var] = comp_var
            modified = True
            # The ClExprOp itself is not copied to the new circuit.
            continue

        if optype == OpType.Barrier:
            # add_gate doesn't work for metaops
            newcirc.add_barrier(args)
        else:
            for arg in args:
                if isinstance(arg, Bit):
                    # reject_dups=False adds the bit only if it is missing,
                    # without rebuilding and scanning ``newcirc.bits`` for
                    # every argument.
                    newcirc.add_bit(arg, reject_dups=False)
            newcirc.add_gate(op, args, **kwargs)
    return newcirc, modified