Coverage for /home/runner/work/tket/tket/pytket/pytket/utils/distribution.py: 89%

114 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-09 15:08 +0000

1# Copyright Quantinuum 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15import warnings 

16from collections import Counter, defaultdict 

17from collections.abc import Callable 

18from typing import Any, Generic, TypeVar, Union 

19 

20import numpy as np 

21from scipy.stats import rv_discrete 

22 

# Numeric type returned by functionals evaluated over a distribution.
Number = Union[float, complex]  # noqa: UP007

# Generic placeholders: T0 is the domain of a distribution, T1 is the domain
# obtained after transforming a distribution with ``map``.
T0 = TypeVar("T0")
T1 = TypeVar("T1")

26 

27 

class EmpiricalDistribution(Generic[T0]):
    """Represents an empirical distribution of values.

    Supports methods for combination, marginalization, expectation value, etc.

    Observations whose count is not strictly positive are discarded on
    construction, so the stored counter only holds values with a positive count.

    >>> dist1 = EmpiricalDistribution(Counter({(0, 0): 3, (0, 1): 2, (1, 0): 4, (1, 1):
    ... 0}))
    >>> dist2 = EmpiricalDistribution(Counter({(0, 0): 1, (0, 1): 0, (1, 0): 2, (1, 1):
    ... 1}))
    >>> dist1.sample_mean(lambda x : x[0] + 2*x[1])
    0.8888888888888888
    >>> dist3 = dist2.condition(lambda x: x[0] == 1)
    >>> dist3
    EmpiricalDistribution(Counter({(1, 0): 2, (1, 1): 1}))
    >>> dist4 = dist1 + dist3
    >>> dist4
    EmpiricalDistribution(Counter({(1, 0): 6, (0, 0): 3, (0, 1): 2, (1, 1): 1}))
    """

    def __init__(self, C: Counter[T0]):
        """Initialize from a counter of observations.

        :param C: Mapping from each observation to the number of times it was
            seen. Entries with a count that is not strictly positive are dropped.
        """
        self._C: Counter[T0] = Counter({x: c for x, c in C.items() if c > 0})

    def as_counter(self) -> Counter[T0]:
        """Return the distribution as a :py:class:`collections.Counter` object.

        The returned object is the counter held by this instance, not a copy.
        """
        return self._C

    @property
    def total(self) -> int:
        """Return the total number of observations."""
        return self._C.total()

    @property
    def support(self) -> set[T0]:
        """Return the support of the distribution (set of all observations)."""
        return set(self._C.keys())

    def __eq__(self, other: object) -> bool:
        """Compare distributions for equality.

        Two distributions are equal when their counters are equal. Comparison
        with any other type is delegated to the other operand.
        """
        if not isinstance(other, EmpiricalDistribution):
            return NotImplemented
        return self._C == other._C

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self._C!r})"

    def __getitem__(self, x: T0) -> int:
        """Get the count associated with an observation.

        Observations that were never seen have a count of zero.
        """
        return self._C[x]

    def __add__(
        self, other: "EmpiricalDistribution[T0]"
    ) -> "EmpiricalDistribution[T0]":
        """Combine two distributions by summing their counts.

        :param other: The distribution to combine with this one.
        :return: A new distribution holding the summed counts.
        """
        # Let Python raise its standard TypeError for unsupported operands
        # instead of failing on a missing ``_C`` attribute.
        if not isinstance(other, EmpiricalDistribution):
            return NotImplemented
        return EmpiricalDistribution(self._C + other._C)

    def condition(self, criterion: Callable[[T0], bool]) -> "EmpiricalDistribution[T0]":
        """Return a new distribution conditioned on the given criterion.

        :param criterion: A boolean function defined on all possible observations.
        :return: A distribution containing only the observations for which
            ``criterion`` is true, with their original counts.
        """
        return EmpiricalDistribution(
            Counter({x: c for x, c in self._C.items() if criterion(x)})
        )

    def map(self, mapping: Callable[[T0], T1]) -> "EmpiricalDistribution[T1]":
        """Return a distribution over a transformed domain.

        The provided function maps elements in the original domain to new elements. If
        it is not injective, counts are combined.

        :param mapping: A function defined on all possible observations, mapping them
            to another domain.
        :return: The distribution of ``mapping(x)`` over the observations ``x``.
        """
        C: Counter[T1] = Counter()
        for x, c in self._C.items():
            C[mapping(x)] += c
        return EmpiricalDistribution(C)

    def sample_mean(self, f: Callable[[T0], Number]) -> Number:
        """Compute the sample mean of a functional.

        The provided function maps observations to numerical values.

        :param f: Function assigning a numerical value to each observation.
        :return: Estimate of the mean of the functional based on the observations.
        :raises ZeroDivisionError: If the distribution contains no observations.
        """
        n = self.total
        if n == 0:
            # Same exception type as the bare division would produce, but with
            # a message that identifies the actual problem.
            raise ZeroDivisionError(
                "Cannot compute the sample mean of an empty distribution."
            )
        return sum(c * f(x) for x, c in self._C.items()) / n

    def sample_variance(self, f: Callable[[T0], Number]) -> Number:
        """Compute the sample variance of a functional.

        The provided function maps observations to numerical values.

        The sample variance is an unbiased estimate of the variance of the underlying
        distribution.

        :param f: Function assigning a numerical value to each observation.
        :return: Estimate of the variance of the functional based on the
            observations.
        :raises RuntimeError: If fewer than two observations are available.
        """
        if self.total < 2:  # noqa: PLR2004
            raise RuntimeError(
                "At least two samples are required in order to compute the sample "
                "variance."
            )
        # Evaluate the functional once per distinct observation.
        fs = [(f(x), c) for x, c in self._C.items()]
        M0 = self.total  # number of samples
        M1 = sum(c * v for v, c in fs)  # sum of values
        M2 = sum(c * v**2 for v, c in fs)  # sum of squared values
        # (sum of squares - n * mean**2) / (n - 1), with mean = M1 / M0.
        return (M2 - M1**2 / M0) / (M0 - 1)

134 

135 

class ProbabilityDistribution(Generic[T0]):
    """Represents an exact probability distribution.

    Supports methods for combination, marginalization, expectation value, etc. May be
    derived from an :py:class:`EmpiricalDistribution`.
    """

    def __init__(self, P: dict[T0, float], min_p: float = 0.0):
        """Initialize with a dictionary of probabilities.

        :param P: Dictionary of probabilities.
        :param min_p: Optional probability below which to ignore values. Default
            0. Distribution is renormalized after removing these values.
        :raises ValueError: If any value is negative, if the values sum to
            (approximately) zero, or if no value exceeds ``min_p``.

        The values must be non-negative. If they do not sum to 1, a warning is
        emitted; the distribution will contain normalized values.
        """
        if any(x < 0 for x in P.values()):
            raise ValueError("Distribution contains negative probabilities")
        s0 = sum(P.values())
        if np.isclose(s0, 0):
            raise ValueError("Distribution has zero weight")
        if not np.isclose(s0, 1):
            warnings.warn(
                "Probabilities used to initialize ProbabilityDistribution do "
                "not sum to 1: renormalizing.",
                stacklevel=2,
            )
        # Only outcomes strictly above the threshold are kept.
        newP = {x: p for x, p in P.items() if p > min_p}
        if not newP:
            # Without this check an over-large ``min_p`` would silently yield
            # an empty "distribution" whose probabilities do not sum to 1.
            raise ValueError("No outcomes have probability greater than min_p")
        s = sum(newP.values())
        self._P = {x: p / s for x, p in newP.items()}

    def as_dict(self) -> dict[T0, float]:
        """Return the distribution as a :py:class:`dict` object.

        The returned object is the dictionary held by this instance, not a copy.
        """
        return self._P

    def as_rv_discrete(self) -> tuple[rv_discrete, list[T0]]:
        """Return the distribution as a :py:class:`scipy.stats.rv_discrete` object.

        This method returns an RV over integers {0, 1, ..., k-1} where k is the size of
        the support, and a list whose i'th member is the item corresponding to the value
        i of the RV.
        """
        X = list(self._P.keys())
        return (rv_discrete(values=(range(len(X)), [self._P[x] for x in X])), X)

    @property
    def support(self) -> set[T0]:
        """Return the support of the distribution (set of all possible outcomes)."""
        return set(self._P.keys())

    def __eq__(self, other: object) -> bool:
        """Compare distributions for equality.

        Distributions are equal when they have the same outcomes and the
        probabilities of each outcome agree within the default tolerance of
        :py:func:`numpy.isclose`.
        """
        if not isinstance(other, ProbabilityDistribution):
            return NotImplemented
        # Dictionary key views compare like sets.
        if self._P.keys() != other._P.keys():
            return False
        return all(np.isclose(self._P[x], other._P[x]) for x in self._P)  # noqa: SLF001

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self._P!r})"

    def __getitem__(self, x: T0) -> float:
        """Get the probability associated with a possible outcome.

        Outcomes outside the support have probability 0.
        """
        return self._P.get(x, 0.0)

    @classmethod
    def from_empirical_distribution(
        cls, ed: EmpiricalDistribution[T0]
    ) -> "ProbabilityDistribution[T0]":
        """Estimate a probability distribution from an empirical distribution.

        :param ed: The empirical distribution whose relative frequencies are used
            as probabilities.
        :raises ValueError: If ``ed`` contains no observations.
        """
        S = ed.total
        if S == 0:
            raise ValueError("Empirical distribution has no values")
        f = 1 / S
        return cls({x: f * c for x, c in ed.as_counter().items()})

    def condition(
        self, criterion: Callable[[T0], bool]
    ) -> "ProbabilityDistribution[T0]":
        """Return a new distribution conditioned on the given criterion.

        :param criterion: A boolean function defined on all possible outcomes.
        :return: The distribution restricted to outcomes satisfying ``criterion``
            and renormalized.
        :raises ValueError: If the outcomes satisfying ``criterion`` have total
            probability (approximately) zero.
        """
        # Evaluate the criterion exactly once per outcome.
        selected = {x: p for x, p in self._P.items() if criterion(x)}
        S = sum(selected.values())
        if np.isclose(S, 0):
            raise ValueError("Condition has probability zero")
        f = 1 / S
        return ProbabilityDistribution({x: f * p for x, p in selected.items()})

    def map(self, mapping: Callable[[T0], T1]) -> "ProbabilityDistribution[T1]":
        """Return a distribution over a transformed domain.

        The provided function maps elements in the original domain to new elements. If
        it is not injective, probabilities are combined.

        :param mapping: A function defined on all possible outcomes, mapping them to
            another domain.
        :return: The distribution of ``mapping(x)`` over the outcomes ``x``.
        """
        P: defaultdict[T1, float] = defaultdict(float)
        for x, p in self._P.items():
            P[mapping(x)] += p
        return ProbabilityDistribution(P)

    def expectation(self, f: Callable[[T0], Number]) -> Number:
        """Compute the expectation value of a functional.

        The provided function maps possible outcomes to numerical values.

        :param f: Function assigning a numerical value to each outcome.
        :return: Expectation of the functional.
        """
        return sum(p * f(x) for x, p in self._P.items())

    def variance(self, f: Callable[[T0], Number]) -> Number:
        """Compute the variance of a functional.

        The provided function maps possible outcomes to numerical values.

        :param f: Function assigning a numerical value to each outcome.
        :return: Variance of the functional.
        """
        # Evaluate the functional once per outcome, then use E[f^2] - E[f]^2.
        fs = [(f(x), p) for x, p in self._P.items()]
        return sum(p * v**2 for v, p in fs) - (sum(p * v for v, p in fs)) ** 2

262 

263 

def convex_combination(
    dists: list[tuple[ProbabilityDistribution[T0], float]],
) -> ProbabilityDistribution[T0]:
    """Mix several probability distributions according to the given weights.

    ``dists`` is a list of ``(distribution, weight)`` pairs. Every weight has to
    be non-negative, and together the weights have to add up to 1.

    :param dists: The distributions to mix, each paired with its weight.
    :return: The weighted mixture of the given distributions.
    :raises ValueError: If a weight is negative or the weights do not sum to 1.

    >>> dist1 = ProbabilityDistribution({0: 0.25, 1: 0.5, 2: 0.25})
    >>> dist2 = ProbabilityDistribution({0: 0.5, 1: 0.5})
    >>> dist3 = convex_combination([(dist1, 0.25), (dist2, 0.75)])
    >>> dist3
    ProbabilityDistribution({0: 0.4375, 1: 0.5, 2: 0.0625})
    >>> dist3.expectation(lambda x : x**2)
    0.75
    """
    mixture: defaultdict[T0, float] = defaultdict(float)
    weight_total = 0.0
    for dist, weight in dists:
        if weight < 0:
            raise ValueError("Weights must be non-negative.")
        weight_total += weight
        # Accumulate this distribution's weighted contribution per outcome.
        for outcome, prob in dist.as_dict().items():
            mixture[outcome] += weight * prob
    if not np.isclose(weight_total, 1):
        raise ValueError("Weights must sum to 1.")
    return ProbabilityDistribution(mixture)