Coverage for /home/runner/work/tket/tket/pytket/pytket/utils/distribution.py: 88%

116 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-07-02 13:01 +0000

1# Copyright Quantinuum 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15import warnings 

16from collections import Counter, defaultdict 

17from collections.abc import Callable 

18from typing import Any, Generic, TypeVar, Union 

19 

20import numpy as np 

21from scipy.stats import rv_discrete 

22 

# Scalar type accepted from, and returned by, the numerical functionals used in
# this module. Spelled with ``typing.Union`` on purpose (see the UP007 waiver).
Number = Union[float, complex]  # noqa: UP007

# Generic parameters: ``T0`` is the outcome type of a distribution, ``T1`` the
# outcome type produced when a distribution is transformed with ``map``.
T0 = TypeVar("T0")
T1 = TypeVar("T1")

26 

27 

class EmpiricalDistribution(Generic[T0]):
    """Represents an empirical distribution of values.

    Supports methods for combination, marginalization, expectation value, etc.

    >>> dist1 = EmpiricalDistribution(Counter({(0, 0): 3, (0, 1): 2, (1, 0): 4, (1, 1):
    ... 0}))
    >>> dist2 = EmpiricalDistribution(Counter({(0, 0): 1, (0, 1): 0, (1, 0): 2, (1, 1):
    ... 1}))
    >>> dist1.sample_mean(lambda x : x[0] + 2*x[1])
    0.8888888888888888
    >>> dist3 = dist2.condition(lambda x: x[0] == 1)
    >>> dist3
    EmpiricalDistribution(Counter({(1, 0): 2, (1, 1): 1}))
    >>> dist4 = dist1 + dist3
    >>> dist4
    EmpiricalDistribution(Counter({(1, 0): 6, (0, 0): 3, (0, 1): 2, (1, 1): 1}))
    """

    def __init__(self, C: Counter[T0]):
        """Initialize from a counter of observations.

        :param C: Counter mapping each observation to the number of times it was
            seen. Entries whose count is not strictly positive are discarded, so
            the stored counter only contains observations that actually occurred.
        """
        self._C: Counter[T0] = Counter({x: c for x, c in C.items() if c > 0})

    def as_counter(self) -> Counter[T0]:
        """Return the distribution as a :py:class:`collections.Counter` object.

        The returned object is the internal counter itself, not a copy.
        """
        return self._C

    @property
    def total(self) -> int:
        """Return the total number of observations."""
        return self._C.total()

    @property
    def support(self) -> set[T0]:
        """Return the support of the distribution (set of all observations)."""
        return set(self._C.keys())

    def __eq__(self, other: object) -> bool:
        """Compare distributions for equality.

        Two empirical distributions are equal when their counters are equal.
        Comparison with any other type is delegated to the other operand.
        """
        if not isinstance(other, EmpiricalDistribution):
            return NotImplemented
        return self._C == other._C

    def __hash__(self) -> int:
        """Return a hash that is consistent with :py:meth:`__eq__`."""
        # ``Counter`` is a ``dict`` subclass and therefore unhashable, so hashing
        # it directly raises ``TypeError``. Hash an order-independent, immutable
        # view of the (observation, count) pairs instead; equal counters yield
        # equal frozensets and hence equal hashes.
        return hash(frozenset(self._C.items()))

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self._C!r})"

    def __getitem__(self, x: T0) -> int:
        """Get the count associated with an observation.

        Observations that were never seen have count 0 (``Counter`` semantics).
        """
        return self._C[x]

    def __add__(
        self, other: "EmpiricalDistribution[T0]"
    ) -> "EmpiricalDistribution[T0]":
        """Combine two distributions by adding their counts."""
        return EmpiricalDistribution(self._C + other._C)

    def condition(self, criterion: Callable[[T0], bool]) -> "EmpiricalDistribution[T0]":
        """Return a new distribution conditioned on the given criterion.

        Only observations for which the criterion holds are retained, with their
        original counts.

        :param criterion: A boolean function defined on all possible observations.
        """
        return EmpiricalDistribution(
            Counter({x: c for x, c in self._C.items() if criterion(x)})
        )

    def map(self, mapping: Callable[[T0], T1]) -> "EmpiricalDistribution[T1]":
        """Return a distribution over a transformed domain.

        The provided function maps elements in the original domain to new elements. If
        it is not injective, counts are combined.

        :param mapping: A function defined on all possible observations, mapping them
            to another domain.
        """
        C: Counter[T1] = Counter()
        for x, c in self._C.items():
            C[mapping(x)] += c
        return EmpiricalDistribution(C)

    def sample_mean(self, f: Callable[[T0], Number]) -> Number:
        """Compute the sample mean of a functional.

        The provided function maps observations to numerical values.

        :param f: Function assigning a numerical value to each observation.
        :return: Estimate of the mean of the functional based on the observations.
        :raises ZeroDivisionError: If the distribution contains no observations.
        """
        return sum(c * f(x) for x, c in self._C.items()) / self.total

    def sample_variance(self, f: Callable[[T0], Number]) -> Number:
        """Compute the sample variance of a functional.

        The provided function maps observations to numerical values.

        The sample variance is an unbiased estimate of the variance of the underlying
        distribution.

        :param f: Function assigning a numerical value to each observation.
        :return: Estimate of the variance of the functional based on the
            observations.
        :raises RuntimeError: If the distribution contains fewer than two
            observations.
        """
        if self.total < 2:  # noqa: PLR2004
            raise RuntimeError(
                "At least two samples are required in order to compute the sample "
                "variance."
            )
        # Evaluate the functional once per distinct observation, then form the
        # count-weighted moments: M0 = n, M1 = sum of values, M2 = sum of squares.
        fs = [(f(x), c) for x, c in self._C.items()]
        M0 = self.total
        M1 = sum(c * v for v, c in fs)
        M2 = sum(c * v**2 for v, c in fs)
        # Dividing by (n - 1) rather than n gives the unbiased estimator.
        return (M2 - M1**2 / M0) / (M0 - 1)

137 

138 

class ProbabilityDistribution(Generic[T0]):  # noqa: PLW1641
    """Represents an exact probability distribution.

    Supports methods for combination, marginalization, expectation value, etc. May be
    derived from an :py:class:`EmpiricalDistribution`.
    """

    def __init__(self, P: dict[T0, float], min_p: float = 0.0):
        """Initialize with a dictionary of probabilities.

        :param P: Dictionary of probabilities.
        :param min_p: Optional threshold: only values strictly greater than it
            are kept. Default 0. Distribution is renormalized after removing
            the other values.
        :raises ValueError: If any value is negative, if the values sum to
            (approximately) zero, or if no value exceeds ``min_p``.

        The values must be non-negative. If they do not sum to 1, a warning is
        emitted; the distribution will contain normalized values.
        """
        if any(x < 0 for x in P.values()):
            raise ValueError("Distribution contains negative probabilities")
        s0 = sum(P.values())
        if np.isclose(s0, 0):
            raise ValueError("Distribution has zero weight")
        if not np.isclose(s0, 1):
            warnings.warn(
                "Probabilities used to initialize ProbabilityDistribution do "
                "not sum to 1: renormalizing.",
                stacklevel=2,
            )
        newP = {x: p for x, p in P.items() if p > min_p}
        if not newP:
            # Without this guard an over-large ``min_p`` would silently produce a
            # distribution with empty support, which cannot be normalized.
            raise ValueError("No outcomes have probability greater than min_p")
        s = sum(newP.values())
        self._P = {x: p / s for x, p in newP.items()}

    def as_dict(self) -> dict[T0, float]:
        """Return the distribution as a :py:class:`dict` object.

        The returned object is the internal dictionary itself, not a copy.
        """
        return self._P

    def as_rv_discrete(self) -> tuple[rv_discrete, list[T0]]:
        """Return the distribution as a :py:class:`scipy.stats.rv_discrete` object.

        This method returns an RV over integers {0, 1, ..., k-1} where k is the size of
        the support, and a list whose i'th member is the item corresponding to the value
        i of the RV.
        """
        X = list(self._P.keys())
        return (rv_discrete(values=(range(len(X)), [self._P[x] for x in X])), X)

    @property
    def support(self) -> set[T0]:
        """Return the support of the distribution (set of all possible outcomes)."""
        return set(self._P.keys())

    def __eq__(self, other: object) -> bool:
        """Compare distributions for equality.

        Distributions are equal when they have the same support and their
        probabilities agree up to the default tolerance of ``numpy.isclose``.
        Comparison with any other type is delegated to the other operand.
        """
        if not isinstance(other, ProbabilityDistribution):
            return NotImplemented
        keys0 = frozenset(self._P.keys())
        keys1 = frozenset(other._P.keys())
        if keys0 != keys1:
            return False
        return all(np.isclose(self._P[x], other._P[x]) for x in keys0)  # noqa: SLF001

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self._P!r})"

    def __getitem__(self, x: T0) -> float:
        """Get the probability associated with a possible outcome.

        Outcomes outside the support have probability 0.
        """
        return self._P.get(x, 0.0)

    @classmethod
    def from_empirical_distribution(
        cls, ed: EmpiricalDistribution[T0]
    ) -> "ProbabilityDistribution[T0]":
        """Estimate a probability distribution from an empirical distribution.

        Each observation is assigned its relative frequency.

        :param ed: Empirical distribution to convert.
        :raises ValueError: If the empirical distribution has no observations.
        """
        S = ed.total
        if S == 0:
            raise ValueError("Empirical distribution has no values")
        f = 1 / S
        return cls({x: f * c for x, c in ed.as_counter().items()})

    def condition(
        self, criterion: Callable[[T0], bool]
    ) -> "ProbabilityDistribution[T0]":
        """Return a new distribution conditioned on the given criterion.

        :param criterion: A boolean function defined on all possible outcomes.
        :raises ValueError: If the outcomes satisfying the criterion have total
            probability (approximately) zero.
        """
        # Evaluate the criterion exactly once per outcome and reuse the selection
        # for both the normalization constant and the resulting distribution.
        selected = {x: p for x, p in self._P.items() if criterion(x)}
        S = sum(selected.values())
        if np.isclose(S, 0):
            raise ValueError("Condition has probability zero")
        # Rescale before constructing so the constructor sees values summing to 1
        # and does not emit its renormalization warning.
        f = 1 / S
        return ProbabilityDistribution({x: f * p for x, p in selected.items()})

    def map(self, mapping: Callable[[T0], T1]) -> "ProbabilityDistribution[T1]":
        """Return a distribution over a transformed domain.

        The provided function maps elements in the original domain to new elements. If
        it is not injective, probabilities are combined.

        :param mapping: A function defined on all possible outcomes, mapping them to
            another domain.
        """
        P: defaultdict[Any, float] = defaultdict(float)
        for x, p in self._P.items():
            P[mapping(x)] += p
        return ProbabilityDistribution(P)

    def expectation(self, f: Callable[[T0], Number]) -> Number:
        """Compute the expectation value of a functional.

        The provided function maps possible outcomes to numerical values.

        :param f: Function assigning a numerical value to each outcome.
        :return: Expectation of the functional.
        """
        return sum(p * f(x) for x, p in self._P.items())

    def variance(self, f: Callable[[T0], Number]) -> Number:
        """Compute the variance of a functional.

        The provided function maps possible outcomes to numerical values.

        :param f: Function assigning a numerical value to each outcome.
        :return: Variance of the functional.
        """
        # Evaluate the functional once per outcome, then use E[f^2] - (E[f])^2.
        fs = [(f(x), p) for x, p in self._P.items()]
        return sum(p * v**2 for v, p in fs) - (sum(p * v for v, p in fs)) ** 2

265 

266 

def convex_combination(
    dists: list[tuple[ProbabilityDistribution[T0], float]],
) -> ProbabilityDistribution[T0]:
    """Return a convex combination of probability distributions.

    The input is a list of (distribution, weight) pairs. Every weight has to be
    non-negative, and together the weights have to add up to 1 (checked with
    ``numpy.isclose``).

    >>> dist1 = ProbabilityDistribution({0: 0.25, 1: 0.5, 2: 0.25})
    >>> dist2 = ProbabilityDistribution({0: 0.5, 1: 0.5})
    >>> dist3 = convex_combination([(dist1, 0.25), (dist2, 0.75)])
    >>> dist3
    ProbabilityDistribution({0: 0.4375, 1: 0.5, 2: 0.0625})
    >>> dist3.expectation(lambda x : x**2)
    0.75

    :param dists: Pairs of a distribution and the weight it contributes with.
    :return: The weighted mixture of the given distributions.
    :raises ValueError: If a weight is negative or the weights do not sum to 1.
    """
    mixture: defaultdict[T0, float] = defaultdict(float)
    weight_sum = 0.0
    for dist, weight in dists:
        if weight < 0:
            raise ValueError("Weights must be non-negative.")
        weight_sum += weight
        # Each outcome receives its probability scaled by the distribution's weight.
        for outcome, prob in dist._P.items():  # noqa: SLF001
            mixture[outcome] += weight * prob
    if not np.isclose(weight_sum, 1):
        raise ValueError("Weights must sum to 1.")
    return ProbabilityDistribution(mixture)