Coverage for /home/runner/work/tket/tket/pytket/pytket/utils/distribution.py: 89%
114 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-14 11:30 +0000
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-14 11:30 +0000
1# Copyright Quantinuum
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
15import warnings
16from collections import Counter, defaultdict
17from collections.abc import Callable
18from typing import Any, Generic, TypeVar, Union
20import numpy as np
21from scipy.stats import rv_discrete
23Number = Union[float, complex]
24T0 = TypeVar("T0")
25T1 = TypeVar("T1")
28class EmpiricalDistribution(Generic[T0]):
29 """Represents an empirical distribution of values.
31 Supports methods for combination, marginalization, expectation value, etc.
33 >>> dist1 = EmpiricalDistribution(Counter({(0, 0): 3, (0, 1): 2, (1, 0): 4, (1, 1):
34 ... 0}))
35 >>> dist2 = EmpiricalDistribution(Counter({(0, 0): 1, (0, 1): 0, (1, 0): 2, (1, 1):
36 ... 1}))
37 >>> dist1.sample_mean(lambda x : x[0] + 2*x[1])
38 0.8888888888888888
39 >>> dist3 = dist2.condition(lambda x: x[0] == 1)
40 >>> dist3
41 EmpiricalDistribution(Counter({(1, 0): 2, (1, 1): 1}))
42 >>> dist4 = dist1 + dist3
43 >>> dist4
44 EmpiricalDistribution(Counter({(1, 0): 6, (0, 0): 3, (0, 1): 2, (1, 1): 1}))
45 """
47 def __init__(self, C: Counter[T0]):
48 self._C: Counter[T0] = Counter({x: c for x, c in C.items() if c > 0})
50 def as_counter(self) -> Counter[T0]:
51 """Return the distribution as a :py:class:`collections.Counter` object."""
52 return self._C
54 @property
55 def total(self) -> int:
56 """Return the total number of observations."""
57 return self._C.total()
59 @property
60 def support(self) -> set[T0]:
61 """Return the support of the distribution (set of all observations)."""
62 return set(self._C.keys())
64 def __eq__(self, other: object) -> bool:
65 """Compare distributions for equality."""
66 if not isinstance(other, EmpiricalDistribution): 66 ↛ 67line 66 didn't jump to line 67 because the condition on line 66 was never true
67 return NotImplemented
68 return self._C == other._C
70 def __repr__(self) -> str:
71 return f"{self.__class__.__name__}({self._C!r})"
73 def __getitem__(self, x: T0) -> int:
74 """Get the count associated with an observation."""
75 return self._C[x]
77 def __add__(
78 self, other: "EmpiricalDistribution[T0]"
79 ) -> "EmpiricalDistribution[T0]":
80 """Combine two distributions."""
81 return EmpiricalDistribution(self._C + other._C)
83 def condition(self, criterion: Callable[[T0], bool]) -> "EmpiricalDistribution[T0]":
84 """Return a new distribution conditioned on the given criterion.
86 :param criterion: A boolean function defined on all possible observations.
87 """
88 return EmpiricalDistribution(
89 Counter({x: c for x, c in self._C.items() if criterion(x)})
90 )
92 def map(self, mapping: Callable[[T0], T1]) -> "EmpiricalDistribution[T1]":
93 """Return a distribution over a transformed domain.
95 The provided function maps elements in the original domain to new elements. If
96 it is not injective, counts are combined.
98 :param mapping: A function defined on all possible observations, mapping them
99 to another domain.
100 """
101 C: Counter[T1] = Counter()
102 for x, c in self._C.items():
103 C[mapping(x)] += c
104 return EmpiricalDistribution(C)
106 def sample_mean(self, f: Callable[[T0], Number]) -> Number:
107 """Compute the sample mean of a functional.
109 The provided function maps observations to numerical values.
111 :return: Estimate of the mean of the functional based on the observations."""
112 return sum(c * f(x) for x, c in self._C.items()) / self.total
114 def sample_variance(self, f: Callable[[T0], Number]) -> Number:
115 """Compute the sample variance of a functional.
117 The provided function maps observations to numerical values.
119 The sample variance is an unbiased estimate of the variance of the underlying
120 distribution.
122 :return: Estimate of the variance of the functional based on the
123 observations."""
124 if self.total < 2: 124 ↛ 125line 124 didn't jump to line 125 because the condition on line 124 was never true
125 raise RuntimeError(
126 "At least two samples are required in order to compute the sample "
127 "variance."
128 )
129 fs = [(f(x), c) for x, c in self._C.items()]
130 M0 = self.total
131 M1 = sum(c * v for v, c in fs)
132 M2 = sum(c * v**2 for v, c in fs)
133 return (M2 - M1**2 / M0) / (M0 - 1)
class ProbabilityDistribution(Generic[T0]):
    """Represents an exact probability distribution.

    Supports methods for combination, marginalization, expectation value, etc. May be
    derived from an :py:class:`EmpiricalDistribution`.
    """

    def __init__(self, P: dict[T0, float], min_p: float = 0.0):
        """Initialize with a dictionary of probabilities.

        :param P: Dictionary of probabilities.
        :param min_p: Optional threshold: outcomes whose probability is at or below
            this value are ignored. Default 0 (so zero-probability outcomes are
            dropped). The distribution is renormalized after removing these values.

        The values must be non-negative. If they do not sum to 1, a warning is
        emitted; the distribution will contain normalized values.

        :raises ValueError: if any probability is negative, if the probabilities sum
            to (approximately) zero, or if no outcome has probability above
            ``min_p``.
        """
        if any(x < 0 for x in P.values()):
            raise ValueError("Distribution contains negative probabilities")
        s0 = sum(P.values())
        if np.isclose(s0, 0):
            raise ValueError("Distribution has zero weight")
        # The normalization check looks at the unfiltered input; outcomes at or
        # below min_p are only dropped afterwards, and the remainder renormalized.
        if not np.isclose(s0, 1):
            warnings.warn(
                "Probabilities used to initialize ProbabilityDistribution do "
                "not sum to 1: renormalizing."
            )
        newP = {x: p for x, p in P.items() if p > min_p}
        if not newP:
            # Without this check a min_p at or above every probability would
            # silently produce an empty (invalid) distribution.
            raise ValueError("Distribution has no probabilities above min_p")
        s = sum(newP.values())
        self._P = {x: p / s for x, p in newP.items()}

    def as_dict(self) -> dict[T0, float]:
        """Return the distribution as a :py:class:`dict` object.

        Note that this is the distribution's internal dictionary, not a copy.
        """
        return self._P

    def as_rv_discrete(self) -> tuple[rv_discrete, list[T0]]:
        """Return the distribution as a :py:class:`scipy.stats.rv_discrete` object.

        This method returns an RV over integers {0, 1, ..., k-1} where k is the size of
        the support, and a list whose i'th member is the item corresponding to the value
        i of the RV.
        """
        X = list(self._P.keys())
        return (rv_discrete(values=(range(len(X)), [self._P[x] for x in X])), X)

    @property
    def support(self) -> set[T0]:
        """Return the support of the distribution (set of all possible outcomes)."""
        return set(self._P.keys())

    def __eq__(self, other: object) -> bool:
        """Compare distributions for equality (probabilities compared with np.isclose)."""
        if not isinstance(other, ProbabilityDistribution):
            return NotImplemented
        keys0 = frozenset(self._P.keys())
        keys1 = frozenset(other._P.keys())
        if keys0 != keys1:
            return False
        return all(np.isclose(self._P[x], other._P[x]) for x in keys0)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self._P!r})"

    def __getitem__(self, x: T0) -> float:
        """Get the probability associated with a possible outcome (0 if absent)."""
        return self._P.get(x, 0.0)

    @classmethod
    def from_empirical_distribution(
        cls, ed: EmpiricalDistribution[T0]
    ) -> "ProbabilityDistribution[T0]":
        """Estimate a probability distribution from an empirical distribution.

        :raises ValueError: if the empirical distribution has no observations.
        """
        S = ed.total
        if S == 0:
            raise ValueError("Empirical distribution has no values")
        f = 1 / S
        return cls({x: f * c for x, c in ed.as_counter().items()})

    def condition(
        self, criterion: Callable[[T0], bool]
    ) -> "ProbabilityDistribution[T0]":
        """Return a new distribution conditioned on the given criterion.

        :param criterion: A boolean function defined on all possible outcomes.
        :raises ValueError: if the outcomes satisfying the criterion have
            (approximately) zero total probability.
        """
        # Evaluate the criterion exactly once per outcome.
        kept = {x: p for x, p in self._P.items() if criterion(x)}
        S = sum(kept.values())
        if np.isclose(S, 0):
            raise ValueError("Condition has probability zero")
        f = 1 / S
        return ProbabilityDistribution({x: f * p for x, p in kept.items()})

    def map(self, mapping: Callable[[T0], T1]) -> "ProbabilityDistribution[T1]":
        """Return a distribution over a transformed domain.

        The provided function maps elements in the original domain to new elements. If
        it is not injective, probabilities are combined.

        :param mapping: A function defined on all possible outcomes, mapping them to
            another domain.
        """
        P: defaultdict[Any, float] = defaultdict(float)
        for x, p in self._P.items():
            P[mapping(x)] += p
        return ProbabilityDistribution(P)

    def expectation(self, f: Callable[[T0], Number]) -> Number:
        """Compute the expectation value of a functional.

        The provided function maps possible outcomes to numerical values.

        :return: Expectation of the functional.
        """
        return sum(p * f(x) for x, p in self._P.items())

    def variance(self, f: Callable[[T0], Number]) -> Number:
        """Compute the variance of a functional.

        The provided function maps possible outcomes to numerical values. The
        deviation from the mean is squared directly (not conjugated), so for
        complex-valued functionals the result is E[(f - E[f])**2].

        :return: Variance of the functional.
        """
        fs = [(f(x), p) for x, p in self._P.items()]
        # Subtract the mean before squaring rather than using E[f^2] - E[f]^2,
        # which cancels catastrophically when the variance is small relative to
        # the squared mean.
        mean = sum(p * v for v, p in fs)
        return sum(p * (v - mean) ** 2 for v, p in fs)
def convex_combination(
    dists: list[tuple[ProbabilityDistribution[T0], float]],
) -> ProbabilityDistribution[T0]:
    """Return a convex combination of probability distributions.

    Each pair in the list comprises a distribution and a weight. The weights must be
    non-negative and sum to 1.

    >>> dist1 = ProbabilityDistribution({0: 0.25, 1: 0.5, 2: 0.25})
    >>> dist2 = ProbabilityDistribution({0: 0.5, 1: 0.5})
    >>> dist3 = convex_combination([(dist1, 0.25), (dist2, 0.75)])
    >>> dist3
    ProbabilityDistribution({0: 0.4375, 1: 0.5, 2: 0.0625})
    >>> dist3.expectation(lambda x : x**2)
    0.75

    :param dists: Pairs ``(distribution, weight)``, consumed in a single pass.
    :return: The distribution giving each outcome the weighted sum of its
        probabilities in the input distributions.
    :raises ValueError: if a weight is negative, or if the weights do not sum to 1
        (within ``numpy.isclose`` tolerance).
    """
    mixture: defaultdict[T0, float] = defaultdict(float)
    weight_sum = 0.0
    for component, weight in dists:
        # A negative weight is rejected as soon as it is encountered.
        if weight < 0:
            raise ValueError("Weights must be non-negative.")
        for outcome, prob in component.as_dict().items():
            mixture[outcome] += weight * prob
        weight_sum += weight
    if np.isclose(weight_sum, 1):
        return ProbabilityDistribution(mixture)
    raise ValueError("Weights must sum to 1.")