# Coverage for /home/runner/work/tket/tket/pytket/pytket/circuit/: 83% (213 statements)
# Report created by coverage.py v7.6.12 at 2025-03-14 11:30 +0000

# Copyright Quantinuum
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""Functions for decomposing Circuits containing classical expressions
into primitive logical operations."""


from heapq import heappop, heappush
from typing import Any, Generic, TypeVar

from pytket._tket.circuit import (
    Circuit,
    ClBitVar,
    ClExpr,
    ClExprOp,
    ClOp,
    ClRegVar,
    Conditional,
    OpType,
    WiredClExpr,
)
from pytket._tket.unit_id import (
    _TEMP_BIT_NAME,
    _TEMP_BIT_REG_BASE,
    _TEMP_REG_SIZE,
    Bit,
    BitRegister,
)
from pytket.circuit.clexpr import check_register_alignments, has_reg_output
from pytket.circuit.logic_exp import Constant, Variable


# Type parameter for the kind of variable held by a VarHeap.
T = TypeVar("T")



class DecomposeClassicalError(Exception):
    """Raised when classical operations in a circuit cannot be decomposed."""



class VarHeap(Generic[T]):
    """Priority heap of reusable variables.

    Alongside the heap of currently free variables, it remembers every
    variable that has been pushed or otherwise registered with it. Callers can
    therefore ask whether a given variable belongs to this heap. Subclasses
    supply :meth:`fresh_var`.
    """

    def __init__(self) -> None:
        # Free variables, kept in heapq order.
        self._heap: list[T] = []
        # Every variable known to this heap, whether free or in use.
        self._heap_vars: set[T] = set()

    def pop(self) -> T:
        """Remove and return the smallest free variable.

        Raises IndexError when no variable is free.
        """
        free = self._heap
        return heappop(free)

    def push(self, var: T) -> None:
        """Mark ``var`` as free for reuse and remember it as a heap variable."""
        free = self._heap
        heappush(free, var)
        self._heap_vars.add(var)

    def is_heap_var(self, var: T) -> bool:
        """Tell whether ``var`` is known to this heap."""
        known = self._heap_vars
        return var in known

    def fresh_var(self) -> T:
        """Produce a variable; subclasses must implement this."""
        raise NotImplementedError



class BitHeap(VarHeap[Bit]):
    """Pool of scratch Bits that can be recycled between operations."""

    def __init__(self, _reg_name: str = _TEMP_BIT_NAME):
        """Create an empty BitHeap.

        :param _reg_name: Register name given to newly created Bits, defaults to
            _TEMP_BIT_NAME
        :type _reg_name: str, optional
        """
        self.reg_name = _reg_name
        super().__init__()

    @property
    def next_index(self) -> int:
        """One past the largest index among the Bits known to this heap (0 if none)."""
        used = [bit.index[0] for bit in self._heap_vars]
        return max(used) + 1 if used else 0

    def fresh_var(self) -> Bit:
        """Hand out a free Bit when one exists, otherwise create and record a new one."""
        if not self._heap:
            created = Bit(self.reg_name, self.next_index)
            self._heap_vars.add(created)
            return created
        return self.pop()



class RegHeap(VarHeap[BitRegister]):
    """Heap of temporary BitRegisters."""

    def __init__(self, _reg_name_base: str = _TEMP_BIT_REG_BASE):
        """Initialise new RegHeap.

        :param _reg_name_base: base string for register names, defaults to
            _TEMP_BIT_REG_BASE
        :type _reg_name_base: str, optional
        """
        self._reg_name_base = _reg_name_base
        super().__init__()

    @property
    def next_index(self) -> int:
        """Next available register index, not used by any other heap register.

        Registers created by :meth:`fresh_var` are named ``f"{base}_{index}"``,
        so the index is parsed from the text after the final underscore.
        Registers whose name does not end in an integer cannot clash with such
        names and are ignored.
        """
        indices = []
        for reg in self._heap_vars:
            suffix = reg.name.rsplit("_", 1)[-1]
            if suffix.isdecimal():
                indices.append(int(suffix))
        return max(indices, default=-1) + 1

    def fresh_var(self, size: int = _TEMP_REG_SIZE) -> BitRegister:
        """Return BitRegister, from heap if available, otherwise create new.

        Optionally set size of created register.

        :param size: size of a newly created register, defaults to _TEMP_REG_SIZE
        :type size: int, optional
        :return: a free register from the heap, or a newly created one
        """
        if self._heap:
            # NOTE(review): a reused register keeps its original size, which
            # may differ from ``size``.
            return self.pop()
        new_reg = BitRegister(f"{self._reg_name_base}_{self.next_index}", size)
        self._heap_vars.add(new_reg)
        return new_reg



def temp_reg_in_args(args: list[Bit]) -> BitRegister | None:
    """Return the temporary register that some bit in ``args`` belongs to.

    Only the first temporary bit found is considered. The returned register is
    assumed to have size ``_TEMP_REG_SIZE``. Returns None if no arg belongs to
    a temporary register.
    """
    temp_names = [
        bit.reg_name for bit in args if bit.reg_name.startswith(_TEMP_BIT_REG_BASE)
    ]
    if not temp_names:
        return None
    return BitRegister(temp_names[0], _TEMP_REG_SIZE)



# Type variable constrained to the Bit class or the BitRegister class.
# NOTE(review): not referenced in the visible part of this module; confirm it is still needed.
VarType = TypeVar("VarType", type[Bit], type[BitRegister])



def int_to_bools(val: Constant, width: int) -> list[bool]:
    """Encode an integer as a little-endian list of ``width`` booleans.

    Element ``i`` of the result is bit ``i`` of ``val`` (least significant bit
    first). Bits at positions ``width`` and above are discarded. Negative values
    are encoded in two's complement.

    :param val: integer value to encode
    :param width: number of bits to produce; must be non-negative
    :return: list of exactly ``width`` booleans
    :raises ValueError: if ``width`` is negative
    """
    if width < 0:
        raise ValueError(f"width must be non-negative, got {width}")
    # Python's >> on ints is arithmetic, so negative values yield their
    # two's-complement bits rather than a sign character.
    return [bool((val >> i) & 1) for i in range(width)]



def get_bit_width(x: int) -> int:
    """Return the number of bits needed to represent the non-negative int ``x``.

    ``get_bit_width(0)`` is 0.

    :param x: non-negative integer
    :return: position of the highest set bit plus one
    """
    assert x >= 0
    # bit_length() always terminates, unlike a shift loop, which spins forever
    # on negative input (-1 >> 1 == -1) when asserts are disabled.
    return x.bit_length()



class ClExprDecomposer:
    """Decompose ClExpr expressions into primitive TKET classical operations.

    The decomposed operations are appended to ``circ``. Scratch Bits and
    BitRegisters needed for literals and intermediate results are taken from
    the supplied heaps and returned to them, so they can be reused.
    """

    def __init__(
        self,
        circ: Circuit,
        bit_posn: dict[int, int],
        reg_posn: dict[int, list[int]],
        args: list[Bit],
        bit_heap: BitHeap,
        reg_heap: RegHeap,
        kwargs: dict[str, Any],
    ) -> None:
        """Initialise a decomposer for one wired expression.

        :param circ: circuit to which decomposed operations are added
        :param bit_posn: map from ClBitVar index to a position in ``args``
        :param reg_posn: map from ClRegVar index to positions in ``args``. The
            register is named after the first position's bit and sized by the
            number of positions.
        :param args: the bits the expression is wired to
        :param bit_heap: heap supplying scratch Bits
        :param reg_heap: heap supplying scratch BitRegisters
        :param kwargs: keyword arguments forwarded to every ``add_c_*`` call
            made on ``circ``
        """
        self.circ: Circuit = circ
        self.bit_posn: dict[int, int] = bit_posn
        self.reg_posn: dict[int, list[int]] = reg_posn
        self.args: list[Bit] = args
        self.bit_heap: BitHeap = bit_heap
        self.reg_heap: RegHeap = reg_heap
        self.kwargs: dict[str, Any] = kwargs
        # Construct maps from int (i.e. ClBitVar) to Bit, and from int (i.e. ClRegVar)
        # to BitRegister:
        self.bit_vars: dict[int, Bit] = {i: args[p] for i, p in bit_posn.items()}
        # NOTE(review): BitRegister(name, len(p)) spans indices 0..len(p)-1 of that
        # register. This assumes the wired bits are exactly those (presumably what
        # register alignment guarantees); confirm.
        self.reg_vars: dict[int, BitRegister] = {
            i: BitRegister(args[p[0]].reg_name, len(p)) for i, p in reg_posn.items()
        }

    def add_var(self, var: Variable) -> None:
        """Add a Bit or BitRegister to the circuit if not already present."""
        if isinstance(var, Bit):
            self.circ.add_bit(var, reject_dups=False)
        else:
            assert isinstance(var, BitRegister)
            # Add the register bit by bit; reject_dups=False ignores bits already present.
            for bit in var.to_list():
                self.circ.add_bit(bit, reject_dups=False)

    def set_bits(self, var: Variable, val: int) -> None:
        """Set the value of a Bit or BitRegister.

        :param var: the Bit or BitRegister to write
        :param val: non-negative value; it must fit in one bit for a Bit, or in
            ``var.size`` bits for a BitRegister (checked by assertions)
        """
        assert val >= 0
        if isinstance(var, Bit):
            assert val >> 1 == 0
            self.circ.add_c_setbits([bool(val)], [var], **self.kwargs)
        else:
            assert isinstance(var, BitRegister)
            assert val >> var.size == 0
            self.circ.add_c_setreg(val, var, **self.kwargs)

    def decompose_expr(self, expr: ClExpr, out_var: Variable | None) -> Variable:
        """Add the decomposed expression to the circuit and return the Bit or
        BitRegister that contains the result.

        Integer literals are first written into scratch variables, and nested
        sub-expressions are decomposed recursively into scratch variables.

        :param expr: the expression to decompose
        :param out_var: where to put the output (if None, create a new scratch location)
        :return: the variable holding the result (``out_var`` itself when given)
        :raises DecomposeClassicalError: if the expression's operation has no
            primitive TKET equivalent
        """
        op: ClOp = expr.op
        # Scratch variables for this expression come from the register heap
        # for register-valued operations, otherwise from the bit heap.
        heap: VarHeap = self.reg_heap if has_reg_output(op) else self.bit_heap

        # Eliminate (recursively) subsidiary expressions from the arguments, and convert
        # all terms to Bit or BitRegister:
        terms: list[Variable] = []
        for arg in expr.args:
            if isinstance(arg, int):
                # Assign to a fresh variable
                fresh_var = heap.fresh_var()
                self.add_var(fresh_var)
                self.set_bits(fresh_var, arg)
                terms.append(fresh_var)
            elif isinstance(arg, ClBitVar):
                terms.append(self.bit_vars[arg.index])
            elif isinstance(arg, ClRegVar):
                terms.append(self.reg_vars[arg.index])
            else:
                assert isinstance(arg, ClExpr)
                terms.append(self.decompose_expr(arg, None))

        # Enable reuse of temporary terms:
        # scratch inputs go back on the heap before the output is allocated,
        # so a newly allocated output may be one of these inputs.
        for term in terms:
            if heap.is_heap_var(term):
                heap.push(term)

        if out_var is None:
            out_var = heap.fresh_var()
            self.add_var(out_var)
        # Map each ClOp onto the corresponding Circuit primitive.
        match op:
            case ClOp.BitAnd:
                self.circ.add_c_and(*terms, out_var, **self.kwargs)  # type: ignore
            case ClOp.BitNot:
                self.circ.add_c_not(*terms, out_var, **self.kwargs)  # type: ignore
            case ClOp.BitOne:
                assert isinstance(out_var, Bit)
                self.circ.add_c_setbits([True], [out_var], **self.kwargs)
            case ClOp.BitOr:
                self.circ.add_c_or(*terms, out_var, **self.kwargs)  # type: ignore
            case ClOp.BitXor:
                self.circ.add_c_xor(*terms, out_var, **self.kwargs)  # type: ignore
            case ClOp.BitZero:
                assert isinstance(out_var, Bit)
                self.circ.add_c_setbits([False], [out_var], **self.kwargs)
            case ClOp.RegAnd:
                self.circ.add_c_and_to_registers(*terms, out_var, **self.kwargs)  # type: ignore
            case ClOp.RegNot:
                self.circ.add_c_not_to_registers(*terms, out_var, **self.kwargs)  # type: ignore
            case ClOp.RegOne:
                assert isinstance(out_var, BitRegister)
                self.circ.add_c_setbits(
                    [True] * out_var.size, out_var.to_list(), **self.kwargs
                )
            case ClOp.RegOr:
                self.circ.add_c_or_to_registers(*terms, out_var, **self.kwargs)  # type: ignore
            case ClOp.RegXor:
                self.circ.add_c_xor_to_registers(*terms, out_var, **self.kwargs)  # type: ignore
            case ClOp.RegZero:
                assert isinstance(out_var, BitRegister)
                self.circ.add_c_setbits(
                    [False] * out_var.size, out_var.to_list(), **self.kwargs
                )
            case _:
                # Any other operation has no primitive equivalent here.
                raise DecomposeClassicalError(
                    f"{op} cannot be decomposed to TKET primitives."
                )
        return out_var



def _decompose_expressions(circ: Circuit) -> tuple[Circuit, bool]:
    """Rewrite a circuit command-wise, decomposing ClExprOp.

    Each ClExprOp command is replaced by primitive classical operations via
    :class:`ClExprDecomposer`. Every other command is copied to a new circuit,
    with conditions and RangePredicate arguments redirected where an
    expression's result ended up in a different location. Temporary (scratch)
    bits of the input circuit are not copied up front. They are re-added only
    when a command uses them.

    :param circ: circuit to rewrite; it is only read, never modified
    :return: the rewritten circuit, and whether any ClExprOp was decomposed
    :raises DecomposeClassicalError: if the circuit contains a
        non-register-aligned ClExprOp, or an expression operation that cannot
        be decomposed
    """
    if not check_register_alignments(circ):
        raise DecomposeClassicalError("Circuit contains non-register-aligned ClExprOp.")
    bit_heap = BitHeap()
    reg_heap = RegHeap()
    # Register scratch variables already present in the circuit with the heaps,
    # so that newly created ones get indices that do not clash with them.
    for b in circ.bits:
        if b.reg_name == _TEMP_BIT_NAME:
            bit_heap._heap_vars.add(b)
        elif b.reg_name.startswith(_TEMP_BIT_REG_BASE):
            reg_heap._heap_vars.add(BitRegister(b.reg_name, _TEMP_REG_SIZE))

    newcirc = Circuit(0, name=circ.name)

    for qb in circ.qubits:
        newcirc.add_qubit(qb)
    for cb in circ.bits:
        # lose all temporary bits, add back as required later
        if not (
            cb.reg_name.startswith(_TEMP_BIT_NAME)
            or cb.reg_name.startswith(_TEMP_BIT_REG_BASE)
        ):
            newcirc.add_bit(cb)

    # targets of predicates that need to be relabelled
    replace_targets: dict[Variable, Variable] = {}
    modified = False
    for command in circ:
        op = command.op
        optype = op.type
        args = command.args
        # Extra keyword arguments (the condition, if any) for the new operations.
        kwargs: dict[str, Any] = {}
        if optype == OpType.Conditional:
            assert isinstance(op, Conditional)
            bits = args[: op.width]
            # check if conditional on previously decomposed expression
            if len(bits) == 1 and bits[0] in replace_targets:
                assert isinstance(bits[0], Bit)
                # this op should encode comparison and value
                assert op.value in (0, 1)
                replace_bit = replace_targets[bits[0]]
                # temporary condition bit is available for reuse
                bit_heap.push(replace_bit)  # type: ignore

                # write new conditional op
                kwargs = {"condition_bits": [replace_bit], "condition_value": op.value}
            else:
                kwargs = {"condition_bits": bits, "condition_value": op.value}
            # Continue with the wrapped operation and its own arguments.
            args = args[op.width :]
            op = op.op
            optype = op.type

        if optype == OpType.RangePredicate:
            target = args[-1]
            assert isinstance(target, Bit)
            newcirc.add_bit(target, reject_dups=False)
            temp_reg = temp_reg_in_args(args)  # type: ignore
            # ensure predicate is reading from correct output register
            if temp_reg in replace_targets:
                assert temp_reg is not None
                new_target = replace_targets[temp_reg]
                for i, a in enumerate(args):
                    if a.reg_name == temp_reg.name:
                        args[i] = Bit(new_target.name, a.index[0])  # type: ignore
                # operations conditional on this bit should remain so
                # NOTE(review): a later command conditioned on ``target`` causes
                # ``target`` to be pushed onto bit_heap (Conditional branch
                # above); presumably ``target`` is always a scratch bit here.
                replace_targets[target] = target

        elif optype == OpType.ClExpr:
            assert isinstance(op, ClExprOp)
            wexpr: WiredClExpr = op.expr
            expr: ClExpr = wexpr.expr
            bit_posn = wexpr.bit_posn
            reg_posn = wexpr.reg_posn
            output_posn = wexpr.output_posn
            assert len(output_posn) > 0
            output0 = args[output_posn[0]]
            assert isinstance(output0, Bit)
            # Register-valued operations write a whole register; others a single bit.
            out_var: Variable = (
                BitRegister(output0.reg_name, len(output_posn))
                if has_reg_output(expr.op)
                else output0
            )
            decomposer = ClExprDecomposer(
                newcirc, bit_posn, reg_posn, args, bit_heap, reg_heap, kwargs  # type: ignore
            )
            comp_var = decomposer.decompose_expr(expr, out_var)
            if comp_var != out_var:
                # NOTE(review): decompose_expr returns out_var whenever one is
                # passed, so this redirect appears unreachable.
                replace_targets[out_var] = comp_var
            modified = True
            continue

        if optype == OpType.Barrier:
            # add_gate doesn't work for metaops
            newcirc.add_barrier(args)
        else:
            # Re-add any (e.g. temporary) bit dropped when the circuit was copied.
            for arg in args:
                if isinstance(arg, Bit) and arg not in newcirc.bits:
                    newcirc.add_bit(arg)
            newcirc.add_gate(op, args, **kwargs)
    return newcirc, modified