Coverage for /home/runner/work/tket/tket/pytket/pytket/circuit/add_condition.py: 94%
56 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-14 11:30 +0000
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-14 11:30 +0000
1# Copyright Quantinuum
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
15"""Enable adding of gates with conditions on Bit or BitRegister expressions."""
17from pytket._tket.unit_id import _TEMP_BIT_NAME, _TEMP_BIT_REG_BASE
18from pytket.circuit import Bit, BitRegister, Circuit
19from pytket.circuit.clexpr import wired_clexpr_from_logic_exp
20from pytket.circuit.logic_exp import (
21 BitLogicExp,
22 Constant,
23 PredicateExp,
24 RegEq,
25 RegGeq,
26 RegGt,
27 RegLeq,
28 RegLogicExp,
29 RegLt,
30 RegNeq,
31)
class NonConstError(Exception):
    """Raised when a predicate's comparison operand is not a constant value."""
def _add_condition(
    circ: Circuit, condition: PredicateExp | Bit | BitLogicExp
) -> tuple[Bit, bool]:
    """Add a condition expression to a circuit using classical expression boxes,
    range predicates and conditionals. Return predicate bit and value of said bit.

    :param circ: Circuit to extend. Depending on ``condition`` this may add a
        scratch bit, a scratch register, a classical expression and/or a range
        predicate to it.
    :param condition: A ``Bit``, a ``BitLogicExp``, or a ``PredicateExp`` whose
        second operand is a constant.
    :return: ``(bit, value)`` such that the condition holds exactly when
        ``bit`` has value ``value``.
    :raises NonConstError: if ``condition`` is a ``PredicateExp`` whose second
        operand is not a constant.
    :raises ValueError: if ``condition`` is not a ``Bit``, ``BitLogicExp`` or
        ``PredicateExp``.
    """
    if isinstance(condition, Bit):
        return condition, True
    if isinstance(condition, PredicateExp):
        pred_exp, pred_val = condition.args
        # PredicateExp constructor should ensure arg order
        if not isinstance(pred_val, Constant):
            # Implicit concatenation (rather than a backslash continuation
            # inside the literal) keeps stray indentation out of the message.
            raise NonConstError(
                "Condition expressions must be of type `PredicateExp` "
                "with a constant second operand."
            )
    elif isinstance(condition, BitLogicExp):
        # A bare bit expression is treated as "expression == 1".
        pred_val = 1
        pred_exp = condition
    else:
        raise ValueError(
            f"Condition {condition} must be of type Bit, BitLogicExp or PredicateExp"
        )

    if isinstance(pred_exp, Bit):
        return pred_exp, bool(pred_val)

    # The resulting condition (a boolean) will be written to a fresh scratch
    # bit, indexed one past the highest scratch bit already in the circuit.
    next_bit_index = 1 + max(
        (bit.index[0] for bit in circ.bits if bit.reg_name == _TEMP_BIT_NAME),
        default=-1,
    )
    condition_bit = Bit(_TEMP_BIT_NAME, next_bit_index)
    circ.add_bit(condition_bit)

    if isinstance(pred_exp, BitLogicExp):
        wexpr, args = wired_clexpr_from_logic_exp(pred_exp, [condition_bit])
        circ.add_clexpr(wexpr, args)
        return condition_bit, bool(pred_val)

    assert isinstance(pred_exp, (RegLogicExp, BitRegister))
    if isinstance(pred_exp, RegLogicExp):
        # Evaluate the register expression into a fresh scratch register whose
        # width is that of the narrowest input register.
        reg_sizes: list[int] = []
        for reg in pred_exp.all_inputs_ordered():
            assert isinstance(reg, BitRegister)
            reg_sizes.append(reg.size)
        existing_reg_names = {
            bit.reg_name
            for bit in circ.bits
            if bit.reg_name.startswith(_TEMP_BIT_REG_BASE)
        }
        # Scratch registers are named "<base>_<n>"; take the next unused n.
        next_reg_index = 1 + max(
            (int(name.split("_")[-1]) for name in existing_reg_names),
            default=-1,
        )
        temp_reg = BitRegister(
            f"{_TEMP_BIT_REG_BASE}_{next_reg_index}", min(reg_sizes)
        )
        circ.add_c_register(temp_reg)
        target_bits = temp_reg.to_list()
        wexpr, args = wired_clexpr_from_logic_exp(pred_exp, target_bits)
        circ.add_clexpr(wexpr, args)
    else:
        target_bits = pred_exp.to_list()

    # Express the comparison as the inclusive range [minval, maxval] on the
    # value held in target_bits; the default upper bound is 2**64 - 1.
    # NOTE(review): RegLt with pred_val == 0 yields maxval == -1, and RegGt with
    # pred_val == 2**64 - 1 yields minval == 2**64; confirm how
    # add_c_range_predicate handles such out-of-range bounds.
    minval = 0
    maxval = (1 << 64) - 1
    if isinstance(condition, RegLt):
        maxval = pred_val - 1
    elif isinstance(condition, RegGt):
        minval = pred_val + 1
    if isinstance(condition, (RegLeq, RegEq, RegNeq)):
        maxval = pred_val
    if isinstance(condition, (RegGeq, RegEq, RegNeq)):
        minval = pred_val

    circ.add_c_range_predicate(minval, maxval, target_bits, condition_bit)
    # RegNeq uses the same range as RegEq; its condition holds when the
    # predicate bit is 0 rather than 1.
    condition_value = not isinstance(condition, RegNeq)
    return condition_bit, condition_value