Coverage for /home/runner/work/tket/tket/pytket/pytket/circuit/decompose_classical.py: 83%

213 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-14 11:30 +0000

1# Copyright Quantinuum 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15"""Functions for decomposing Circuits containing classical expressions 

16in to primitive logical operations.""" 

17 

18from heapq import heappop, heappush 

19from typing import Any, Generic, TypeVar 

20 

21from pytket._tket.circuit import ( 

22 Circuit, 

23 ClBitVar, 

24 ClExpr, 

25 ClExprOp, 

26 ClOp, 

27 ClRegVar, 

28 Conditional, 

29 OpType, 

30 WiredClExpr, 

31) 

32from pytket._tket.unit_id import ( 

33 _TEMP_BIT_NAME, 

34 _TEMP_BIT_REG_BASE, 

35 _TEMP_REG_SIZE, 

36 Bit, 

37 BitRegister, 

38) 

39from pytket.circuit.clexpr import check_register_alignments, has_reg_output 

40from pytket.circuit.logic_exp import Constant, Variable 

41 

# Element type stored by a VarHeap: BitHeap binds it to Bit and RegHeap binds
# it to BitRegister.
T = TypeVar("T")

43 

44 

class DecomposeClassicalError(Exception):
    """Raised when a classical operation cannot be decomposed.

    Within this module it signals either a circuit whose ClExprOp arguments
    are not register-aligned, or an expression that uses an operation with no
    primitive equivalent.
    """

47 

48 

class VarHeap(Generic[T]):
    """Min-heap of reusable variables.

    Two collections are maintained: ``_heap`` holds the variables currently
    available for reuse, while ``_heap_vars`` remembers every variable that
    has ever been associated with this heap, whether or not it is currently
    available.
    """

    def __init__(self) -> None:
        # Variables free for reuse, kept in heapq order.
        self._heap: list[T] = []
        # Every variable known to belong to this heap.
        self._heap_vars: set[T] = set()

    def fresh_var(self) -> T:
        """Produce a variable to use; concrete heaps must override this."""
        raise NotImplementedError

    def is_heap_var(self, var: T) -> bool:
        """Report whether ``var`` is one of the variables owned by this heap."""
        return var in self._heap_vars

    def push(self, var: T) -> None:
        """Make ``var`` available for reuse and record it as owned."""
        heappush(self._heap, var)
        self._heap_vars.add(var)

    def pop(self) -> T:
        """Remove and return the smallest available variable."""
        return heappop(self._heap)

72 

73 

class BitHeap(VarHeap[Bit]):
    """Pool of scratch Bits that all live in a single named register."""

    def __init__(self, _reg_name: str = _TEMP_BIT_NAME):
        """Create an empty pool.

        :param _reg_name: register name given to newly created Bits, defaults
            to _TEMP_BIT_NAME
        :type _reg_name: str, optional
        """
        self.reg_name = _reg_name
        super().__init__()

    @property
    def next_index(self) -> int:
        """Smallest index strictly above every index already owned (0 if none)."""
        used_indices = [bit.index[0] for bit in self._heap_vars]
        if not used_indices:
            return 0
        return max(used_indices) + 1

    def fresh_var(self) -> Bit:
        """Hand out a Bit, recycling a released one when possible."""
        if not self._heap:
            # Nothing to recycle: mint a Bit at the next unused index and
            # remember that it belongs to this pool.
            created = Bit(self.reg_name, self.next_index)
            self._heap_vars.add(created)
            return created
        return self.pop()

99 

100 

class RegHeap(VarHeap[BitRegister]):
    """Pool of scratch BitRegisters named ``<base>_<n>``."""

    def __init__(self, _reg_name_base: str = _TEMP_BIT_REG_BASE):
        """Create an empty pool.

        :param _reg_name_base: prefix for the names of newly created
            registers, defaults to _TEMP_BIT_REG_BASE
        :type _reg_name_base: str, optional
        """
        self._reg_name_base = _reg_name_base
        super().__init__()

    @property
    def next_index(self) -> int:
        """Smallest numeric suffix strictly above every owned register's suffix.

        The suffix is whatever follows the last underscore in a register's
        name and must parse as an integer.
        """
        suffixes = [int(reg.name.rsplit("_", 1)[-1]) for reg in self._heap_vars]
        if not suffixes:
            return 0
        return max(suffixes) + 1

    def fresh_var(self, size: int = _TEMP_REG_SIZE) -> BitRegister:
        """Hand out a BitRegister, recycling a released one when possible.

        ``size`` only applies when a new register has to be created; a
        recycled register is returned as it is.
        """
        if not self._heap:
            created = BitRegister(f"{self._reg_name_base}_{self.next_index}", size)
            self._heap_vars.add(created)
            return created
        return self.pop()

130 

131 

def temp_reg_in_args(args: list[Bit]) -> BitRegister | None:
    """Find the scratch register, if any, that the given bits belong to.

    :param args: bits to inspect
    :return: a BitRegister of size _TEMP_REG_SIZE named after the first bit
        whose register name starts with _TEMP_BIT_REG_BASE, or None when no
        such bit is present
    """
    scratch_names = []
    for bit in args:
        name = bit.reg_name
        if name.startswith(_TEMP_BIT_REG_BASE):
            scratch_names.append(name)
    if not scratch_names:
        return None
    return BitRegister(scratch_names[0], _TEMP_REG_SIZE)

138 

139 

# Type variable constrained to the Bit class or the BitRegister class
# themselves (type[...]), not to instances of them.
VarType = TypeVar("VarType", type[Bit], type[BitRegister])

141 

142 

def int_to_bools(val: Constant, width: int) -> list[bool]:
    """Encode an integer as a little-endian list of booleans.

    Only the ``width`` least-significant bits of ``val`` are kept; a value
    narrower than ``width`` is padded with ``False``.

    :param val: non-negative integer to encode
    :param width: number of bits to produce
    :return: ``width`` booleans, least-significant bit first
    :raises ValueError: if ``val`` or ``width`` is negative
    """
    if val < 0:
        raise ValueError(f"Cannot encode negative value {val} as bits.")
    if width < 0:
        raise ValueError(f"Bit width must be non-negative, got {width}.")
    # Shift-and-mask avoids slicing a formatted string: the slice [-0:] keeps
    # the whole string, so a width of 0 used to yield bits instead of [].
    return [bool((val >> position) & 1) for position in range(width)]

146 

147 

def get_bit_width(x: int) -> int:
    """Return the number of bits needed to represent ``x``.

    :param x: non-negative integer
    :return: index of the highest set bit plus one; 0 when ``x`` is 0
    """
    assert x >= 0
    # int.bit_length() gives the same count as shifting right until zero, and
    # it always terminates: a shift loop never reaches zero for a negative
    # value if the assertion above has been stripped (python -O).
    return x.bit_length()

155 

156 

class ClExprDecomposer:
    """Append primitive classical gates equivalent to one wired expression.

    Gates are added to ``circ``. Scratch Bits and BitRegisters are drawn from
    the supplied heaps, and ``kwargs`` is forwarded to every gate-adding call.
    """

    def __init__(
        self,
        circ: Circuit,
        bit_posn: dict[int, int],
        reg_posn: dict[int, list[int]],
        args: list[Bit],
        bit_heap: BitHeap,
        reg_heap: RegHeap,
        kwargs: dict[str, Any],
    ):
        """Store the inputs and resolve expression variables to circuit units.

        :param circ: circuit that receives the decomposed operations
        :param bit_posn: ClBitVar index -> position in ``args``
        :param reg_posn: ClRegVar index -> positions in ``args`` of its bits
        :param args: bits the expression is wired to
        :param bit_heap: source of scratch Bits
        :param reg_heap: source of scratch BitRegisters
        :param kwargs: extra keyword arguments passed to every added gate
        """
        self.circ: Circuit = circ
        self.bit_posn: dict[int, int] = bit_posn
        self.reg_posn: dict[int, list[int]] = reg_posn
        self.args: list[Bit] = args
        self.bit_heap: BitHeap = bit_heap
        self.reg_heap: RegHeap = reg_heap
        self.kwargs: dict[str, Any] = kwargs
        # ClBitVar index -> the Bit it is wired to.
        self.bit_vars = {}
        for var_index, position in bit_posn.items():
            self.bit_vars[var_index] = args[position]
        # ClRegVar index -> a BitRegister named after its first wired bit and
        # as wide as the number of positions it occupies.
        self.reg_vars = {}
        for var_index, positions in reg_posn.items():
            self.reg_vars[var_index] = BitRegister(
                args[positions[0]].reg_name, len(positions)
            )

    def add_var(self, var: Variable) -> None:
        """Ensure the Bit, or every bit of the BitRegister, is in the circuit."""
        if isinstance(var, Bit):
            self.circ.add_bit(var, reject_dups=False)
            return
        assert isinstance(var, BitRegister)
        for member in var.to_list():
            self.circ.add_bit(member, reject_dups=False)

    def set_bits(self, var: Variable, val: int) -> None:
        """Add an operation that writes the constant ``val`` into ``var``."""
        assert val >= 0
        if isinstance(var, Bit):
            # A single bit can only hold 0 or 1.
            assert val >> 1 == 0
            self.circ.add_c_setbits([bool(val)], [var], **self.kwargs)
            return
        assert isinstance(var, BitRegister)
        # The constant must fit in the register.
        assert val >> var.size == 0
        self.circ.add_c_setreg(val, var, **self.kwargs)

    def _operand(self, arg: Any, heap: VarHeap) -> Variable:
        """Turn one expression argument into the Bit/BitRegister holding it."""
        if isinstance(arg, int):
            # Constants are materialised in a scratch variable from ``heap``.
            scratch = heap.fresh_var()
            self.add_var(scratch)
            self.set_bits(scratch, arg)
            return scratch
        if isinstance(arg, ClBitVar):
            return self.bit_vars[arg.index]
        if isinstance(arg, ClRegVar):
            return self.reg_vars[arg.index]
        assert isinstance(arg, ClExpr)
        # Nested expression: decompose it into a scratch location first.
        return self.decompose_expr(arg, None)

    def decompose_expr(self, expr: ClExpr, out_var: Variable | None) -> Variable:
        """Decompose ``expr`` into the circuit and say where the result lives.

        :param expr: the expression to decompose
        :param out_var: destination for the result; when None a scratch
            location is taken from the heap matching the operation's output
        :return: the Bit or BitRegister that holds the result
        :raises DecomposeClassicalError: if the operation has no primitive
            equivalent
        """
        op: ClOp = expr.op
        heap: VarHeap = self.bit_heap
        if has_reg_output(op):
            heap = self.reg_heap

        # Resolve operands left to right (recursing into sub-expressions).
        terms: list[Variable] = [self._operand(arg, heap) for arg in expr.args]

        # Scratch operands are released before the output is chosen, so the
        # output may reuse one of them.
        for operand in terms:
            if heap.is_heap_var(operand):
                heap.push(operand)

        if out_var is None:
            out_var = heap.fresh_var()
        self.add_var(out_var)

        # Operations that map onto a circuit method taking (*operands, output).
        adder_names = {
            ClOp.BitAnd: "add_c_and",
            ClOp.BitNot: "add_c_not",
            ClOp.BitOr: "add_c_or",
            ClOp.BitXor: "add_c_xor",
            ClOp.RegAnd: "add_c_and_to_registers",
            ClOp.RegNot: "add_c_not_to_registers",
            ClOp.RegOr: "add_c_or_to_registers",
            ClOp.RegXor: "add_c_xor_to_registers",
        }
        # Constant-valued operations and the value they write.
        bit_constants = {ClOp.BitOne: True, ClOp.BitZero: False}
        reg_constants = {ClOp.RegOne: True, ClOp.RegZero: False}

        if op in adder_names:
            add_gate = getattr(self.circ, adder_names[op])
            add_gate(*terms, out_var, **self.kwargs)
        elif op in bit_constants:
            assert isinstance(out_var, Bit)
            self.circ.add_c_setbits([bit_constants[op]], [out_var], **self.kwargs)
        elif op in reg_constants:
            assert isinstance(out_var, BitRegister)
            self.circ.add_c_setbits(
                [reg_constants[op]] * out_var.size, out_var.to_list(), **self.kwargs
            )
        else:
            raise DecomposeClassicalError(
                f"{op} cannot be decomposed to TKET primitives."
            )
        return out_var

276 

277 

def _decompose_expressions(circ: Circuit) -> tuple[Circuit, bool]:
    """Rewrite a circuit command-wise, decomposing ClExprOp.

    A new circuit is built with the same name, qubits and non-scratch bits as
    ``circ``. Commands are copied across in order, except that every ClExprOp
    (conditional or not) is replaced by the primitive classical operations
    emitted by ClExprDecomposer.

    NOTE(review): the global phase of ``circ`` is not copied to the new
    circuit here -- confirm that callers do not rely on it.

    :param circ: circuit to rewrite; it is only read
    :return: the rewritten circuit, and True iff at least one ClExprOp was
        decomposed
    :raises DecomposeClassicalError: if ``check_register_alignments(circ)``
        fails, or an expression uses an operation without a primitive
        equivalent
    """
    if not check_register_alignments(circ):
        raise DecomposeClassicalError("Circuit contains non-register-aligned ClExprOp.")
    bit_heap = BitHeap()
    reg_heap = RegHeap()
    # add already used heap variables to heaps, so that freshly created
    # scratch variables get indices above the ones already in the circuit
    for b in circ.bits:
        if b.reg_name == _TEMP_BIT_NAME:
            bit_heap._heap_vars.add(b)
        elif b.reg_name.startswith(_TEMP_BIT_REG_BASE):
            reg_heap._heap_vars.add(BitRegister(b.reg_name, _TEMP_REG_SIZE))

    newcirc = Circuit(0, name=circ.name)

    for qb in circ.qubits:
        newcirc.add_qubit(qb)
    for cb in circ.bits:
        # lose all temporary bits, add back as required later
        if not cb.reg_name.startswith((_TEMP_BIT_NAME, _TEMP_BIT_REG_BASE)):
            newcirc.add_bit(cb)

    # targets of predicates that need to be relabelled
    replace_targets: dict[Variable, Variable] = {}
    modified = False
    for command in circ:
        op = command.op
        optype = op.type
        args = command.args
        # keyword arguments carrying the condition (if any) to the new gates
        kwargs: dict[str, Any] = {}
        if optype == OpType.Conditional:
            assert isinstance(op, Conditional)
            bits = args[: op.width]
            # check if conditional on previously decomposed expression
            if len(bits) == 1 and bits[0] in replace_targets:
                assert isinstance(bits[0], Bit)
                # this op should encode comparison and value
                assert op.value in (0, 1)
                replace_bit = replace_targets[bits[0]]
                # temporary condition bit is available for reuse
                # NOTE(review): range-predicate targets are recorded below as
                # their own replacement, so this may also push a bit that is
                # not a scratch bit -- confirm this is intended.
                bit_heap.push(replace_bit)  # type: ignore

                # write new conditional op
                kwargs = {"condition_bits": [replace_bit], "condition_value": op.value}
            else:
                kwargs = {"condition_bits": bits, "condition_value": op.value}
            # unwrap: continue with the inner op and its own arguments
            args = args[op.width :]
            op = op.op
            optype = op.type

        if optype == OpType.RangePredicate:
            target = args[-1]
            assert isinstance(target, Bit)
            newcirc.add_bit(target, reject_dups=False)
            temp_reg = temp_reg_in_args(args)  # type: ignore
            # ensure predicate is reading from correct output register
            if temp_reg in replace_targets:
                assert temp_reg is not None
                new_target = replace_targets[temp_reg]
                for i, a in enumerate(args):
                    if a.reg_name == temp_reg.name:
                        args[i] = Bit(new_target.name, a.index[0])  # type: ignore
            # operations conditional on this bit should remain so
            replace_targets[target] = target

        elif optype == OpType.ClExpr:
            assert isinstance(op, ClExprOp)
            wexpr: WiredClExpr = op.expr
            expr: ClExpr = wexpr.expr
            bit_posn = wexpr.bit_posn
            reg_posn = wexpr.reg_posn
            output_posn = wexpr.output_posn
            assert len(output_posn) > 0
            output0 = args[output_posn[0]]
            assert isinstance(output0, Bit)
            # A register-valued expression writes to the register of its first
            # output bit; a bit-valued one writes to that bit directly.
            out_var: Variable = (
                BitRegister(output0.reg_name, len(output_posn))
                if has_reg_output(expr.op)
                else output0
            )
            decomposer = ClExprDecomposer(
                newcirc, bit_posn, reg_posn, args, bit_heap, reg_heap, kwargs  # type: ignore
            )
            comp_var = decomposer.decompose_expr(expr, out_var)
            if comp_var != out_var:
                replace_targets[out_var] = comp_var
            modified = True
            # the decomposer has already added the replacement operations
            continue

        if optype == OpType.Barrier:
            # add_gate doesn't work for metaops
            newcirc.add_barrier(args)
        else:
            for arg in args:
                if isinstance(arg, Bit):
                    # reject_dups=False makes this a no-op for bits already
                    # present, avoiding a scan of newcirc.bits per argument
                    newcirc.add_bit(arg, reject_dups=False)
            newcirc.add_gate(op, args, **kwargs)
    return newcirc, modified