Coverage for /home/runner/work/tket/tket/pytket/pytket/circuit/add_condition.py: 94%

56 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-14 11:30 +0000

1# Copyright Quantinuum 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15"""Enable adding of gates with conditions on Bit or BitRegister expressions.""" 

16 

17from pytket._tket.unit_id import _TEMP_BIT_NAME, _TEMP_BIT_REG_BASE 

18from pytket.circuit import Bit, BitRegister, Circuit 

19from pytket.circuit.clexpr import wired_clexpr_from_logic_exp 

20from pytket.circuit.logic_exp import ( 

21 BitLogicExp, 

22 Constant, 

23 PredicateExp, 

24 RegEq, 

25 RegGeq, 

26 RegGt, 

27 RegLeq, 

28 RegLogicExp, 

29 RegLt, 

30 RegNeq, 

31) 

32 

33 

class NonConstError(Exception):
    """Raised when the second operand of a predicate expression is not a constant."""

36 

37 

def _add_condition(
    circ: Circuit, condition: PredicateExp | Bit | BitLogicExp
) -> tuple[Bit, bool]:
    """Add a condition expression to a circuit using classical expression boxes,
    range predicates and conditionals.

    If the condition is a plain bit, or a predicate whose first operand is a
    plain bit, that bit is returned directly and nothing is added to ``circ``.
    Otherwise a scratch bit in the ``_TEMP_BIT_NAME`` register is added to
    ``circ`` and the operations needed to write the condition to it are
    appended:

    * a bit-valued logic expression is written to the scratch bit with a
      classical expression box;
    * a register-valued logic expression is first written to a new scratch
      register (named from ``_TEMP_BIT_REG_BASE``) with a classical expression
      box, and a range predicate over that register sets the scratch bit;
    * a plain register is fed to the range predicate directly.

    :param circ: Circuit to which scratch bits, scratch registers and
        operations are added.
    :param condition: A bit, a bit-valued logic expression, or a predicate
        expression whose second operand is a constant.
    :return: The bit holding the predicate, and the value that bit must have
        for the condition to be satisfied.
    :raises NonConstError: If ``condition`` is a predicate expression whose
        second operand is not a constant.
    :raises ValueError: If ``condition`` is not a ``Bit``, ``BitLogicExp`` or
        ``PredicateExp``.
    """
    if isinstance(condition, Bit):
        return condition, True

    if isinstance(condition, PredicateExp):
        pred_exp, pred_val = condition.args
        # PredicateExp constructor should ensure arg order
        if not isinstance(pred_val, Constant):
            # Adjacent literals rather than a backslash continuation inside the
            # string, which would embed the source indentation in the message.
            raise NonConstError(
                "Condition expressions must be of type `PredicateExp` "
                "with a constant second operand."
            )
    elif isinstance(condition, BitLogicExp):
        # A bare bit expression is treated as "expression == 1".
        pred_val = 1
        pred_exp = condition
    else:
        raise ValueError(
            f"Condition {condition} must be of type Bit, BitLogicExp or PredicateExp"
        )

    if isinstance(pred_exp, Bit):
        # No scratch bit is needed: the existing bit already holds the value.
        # NOTE(review): only pred_val is consulted here; if logic_exp defines a
        # negated bit predicate, the returned value would need inverting - confirm.
        return pred_exp, bool(pred_val)

    # The resulting condition (a boolean) will be written to this scratch bit,
    # placed one past the highest index already used in the scratch register.
    next_bit_index = (
        max(
            (bit.index[0] for bit in circ.bits if bit.reg_name == _TEMP_BIT_NAME),
            default=-1,
        )
        + 1
    )
    condition_bit = Bit(_TEMP_BIT_NAME, next_bit_index)
    circ.add_bit(condition_bit)

    if isinstance(pred_exp, BitLogicExp):
        wexpr, args = wired_clexpr_from_logic_exp(pred_exp, [condition_bit])
        circ.add_clexpr(wexpr, args)
        return condition_bit, bool(pred_val)

    assert isinstance(pred_exp, (RegLogicExp, BitRegister))
    if isinstance(pred_exp, RegLogicExp):
        # The scratch register is as wide as the narrowest input register.
        reg_sizes: list[int] = []
        for reg in pred_exp.all_inputs_ordered():
            assert isinstance(reg, BitRegister)
            reg_sizes.append(reg.size)
        min_reg_size = min(reg_sizes)

        # Scratch registers are named "<base>_<n>"; pick the next unused n.
        existing_reg_names = {
            bit.reg_name
            for bit in circ.bits
            if bit.reg_name.startswith(_TEMP_BIT_REG_BASE)
        }
        next_reg_index = (
            max(
                (int(r_name.split("_")[-1]) for r_name in existing_reg_names),
                default=-1,
            )
            + 1
        )
        temp_reg = BitRegister(f"{_TEMP_BIT_REG_BASE}_{next_reg_index}", min_reg_size)
        circ.add_c_register(temp_reg)
        target_bits = temp_reg.to_list()
        wexpr, args = wired_clexpr_from_logic_exp(pred_exp, target_bits)
        circ.add_clexpr(wexpr, args)
    else:
        # The assert above leaves BitRegister as the only remaining case.
        target_bits = pred_exp.to_list()

    # Translate the comparison into an inclusive [minval, maxval] range.
    # NOTE(review): `reg < 0` yields maxval == -1 and `reg > max` yields
    # minval == 1 << 64; confirm how add_c_range_predicate treats such bounds.
    minval = 0
    maxval = (1 << 64) - 1
    if isinstance(condition, RegLt):
        maxval = pred_val - 1
    elif isinstance(condition, RegGt):
        minval = pred_val + 1
    if isinstance(condition, (RegLeq, RegEq, RegNeq)):
        maxval = pred_val
    if isinstance(condition, (RegGeq, RegEq, RegNeq)):
        minval = pred_val

    circ.add_c_range_predicate(minval, maxval, target_bits, condition_bit)
    # "!=" uses the same single-value range as "==", but the condition holds
    # when the range-predicate bit is 0.
    return condition_bit, not isinstance(condition, RegNeq)