Coverage for /home/runner/work/tket/tket/pytket/pytket/utils/distribution.py: 88%
116 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-04 14:20 +0000
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-04 14:20 +0000
1# Copyright Quantinuum
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
15import warnings
16from collections import Counter, defaultdict
17from collections.abc import Callable
18from typing import Any, Generic, TypeVar, Union
20import numpy as np
21from scipy.stats import rv_discrete
# Numeric type used for the values of functionals passed to the mean/variance
# style methods in this module. The ``noqa`` suppresses the linter rule that
# would otherwise flag the explicit ``Union`` spelling.
Number = Union[float, complex]  # noqa: UP007
# Type of the observations / outcomes over which a distribution is defined.
T0 = TypeVar("T0")
# Type of the target domain produced by the ``map`` methods.
T1 = TypeVar("T1")
28class EmpiricalDistribution(Generic[T0]):
29 """Represents an empirical distribution of values.
31 Supports methods for combination, marginalization, expectation value, etc.
33 >>> dist1 = EmpiricalDistribution(Counter({(0, 0): 3, (0, 1): 2, (1, 0): 4, (1, 1):
34 ... 0}))
35 >>> dist2 = EmpiricalDistribution(Counter({(0, 0): 1, (0, 1): 0, (1, 0): 2, (1, 1):
36 ... 1}))
37 >>> dist1.sample_mean(lambda x : x[0] + 2*x[1])
38 0.8888888888888888
39 >>> dist3 = dist2.condition(lambda x: x[0] == 1)
40 >>> dist3
41 EmpiricalDistribution(Counter({(1, 0): 2, (1, 1): 1}))
42 >>> dist4 = dist1 + dist3
43 >>> dist4
44 EmpiricalDistribution(Counter({(1, 0): 6, (0, 0): 3, (0, 1): 2, (1, 1): 1}))
45 """
47 def __init__(self, C: Counter[T0]):
48 self._C: Counter[T0] = Counter({x: c for x, c in C.items() if c > 0})
50 def as_counter(self) -> Counter[T0]:
51 """Return the distribution as a :py:class:`collections.Counter` object."""
52 return self._C
54 @property
55 def total(self) -> int:
56 """Return the total number of observations."""
57 return self._C.total()
59 @property
60 def support(self) -> set[T0]:
61 """Return the support of the distribution (set of all observations)."""
62 return set(self._C.keys())
64 def __eq__(self, other: object) -> bool:
65 """Compare distributions for equality."""
66 if not isinstance(other, EmpiricalDistribution): 66 ↛ 67line 66 didn't jump to line 67 because the condition on line 66 was never true
67 return NotImplemented
68 return self._C == other._C
70 def __hash__(self) -> int:
71 return hash(self._C)
73 def __repr__(self) -> str:
74 return f"{self.__class__.__name__}({self._C!r})"
76 def __getitem__(self, x: T0) -> int:
77 """Get the count associated with an observation."""
78 return self._C[x]
80 def __add__(
81 self, other: "EmpiricalDistribution[T0]"
82 ) -> "EmpiricalDistribution[T0]":
83 """Combine two distributions."""
84 return EmpiricalDistribution(self._C + other._C)
86 def condition(self, criterion: Callable[[T0], bool]) -> "EmpiricalDistribution[T0]":
87 """Return a new distribution conditioned on the given criterion.
89 :param criterion: A boolean function defined on all possible observations.
90 """
91 return EmpiricalDistribution(
92 Counter({x: c for x, c in self._C.items() if criterion(x)})
93 )
95 def map(self, mapping: Callable[[T0], T1]) -> "EmpiricalDistribution[T1]":
96 """Return a distribution over a transformed domain.
98 The provided function maps elements in the original domain to new elements. If
99 it is not injective, counts are combined.
101 :param mapping: A function defined on all possible observations, mapping them
102 to another domain.
103 """
104 C: Counter[T1] = Counter()
105 for x, c in self._C.items():
106 C[mapping(x)] += c
107 return EmpiricalDistribution(C)
109 def sample_mean(self, f: Callable[[T0], Number]) -> Number:
110 """Compute the sample mean of a functional.
112 The provided function maps observations to numerical values.
114 :return: Estimate of the mean of the functional based on the observations."""
115 return sum(c * f(x) for x, c in self._C.items()) / self.total
117 def sample_variance(self, f: Callable[[T0], Number]) -> Number:
118 """Compute the sample variance of a functional.
120 The provided function maps observations to numerical values.
122 The sample variance is an unbiased estimate of the variance of the underlying
123 distribution.
125 :return: Estimate of the variance of the functional based on the
126 observations."""
127 if self.total < 2: # noqa: PLR2004 127 ↛ 128line 127 didn't jump to line 128 because the condition on line 127 was never true
128 raise RuntimeError(
129 "At least two samples are required in order to compute the sample "
130 "variance."
131 )
132 fs = [(f(x), c) for x, c in self._C.items()]
133 M0 = self.total
134 M1 = sum(c * v for v, c in fs)
135 M2 = sum(c * v**2 for v, c in fs)
136 return (M2 - M1**2 / M0) / (M0 - 1)
class ProbabilityDistribution(Generic[T0]):  # noqa: PLW1641
    """Represents an exact probability distribution.

    Supports methods for combination, marginalization, expectation value, etc. May be
    derived from an :py:class:`EmpiricalDistribution`.
    """

    def __init__(self, P: dict[T0, float], min_p: float = 0.0):
        """Initialize with a dictionary of probabilities.

        The values must be non-negative. If they do not sum to 1, a warning is
        emitted; the distribution will contain normalized values.

        :param P: Dictionary of probabilities.
        :param min_p: Optional probability below which to ignore values. Default
            0. Distribution is renormalized after removing these values.
        :raises ValueError: If any probability is negative, if the total weight is
            (numerically) zero, or if no probability exceeds ``min_p``.
        """
        if any(x < 0 for x in P.values()):
            raise ValueError("Distribution contains negative probabilities")
        s0 = sum(P.values())
        if np.isclose(s0, 0):
            raise ValueError("Distribution has zero weight")
        if not np.isclose(s0, 1):
            warnings.warn(
                "Probabilities used to initialize ProbabilityDistribution do "
                "not sum to 1: renormalizing.",
                stacklevel=2,
            )
        newP = {x: p for x, p in P.items() if p > min_p}
        # Without this check a too-large ``min_p`` would silently produce an empty
        # "distribution" of total weight 0, which is as invalid as the zero-weight
        # input rejected above.
        if not newP:
            raise ValueError("Distribution has no values with probability above min_p")
        s = sum(newP.values())
        self._P = {x: p / s for x, p in newP.items()}

    def as_dict(self) -> dict[T0, float]:
        """Return the distribution as a :py:class:`dict` object.

        The returned object is the internal dictionary itself, not a copy.
        """
        return self._P

    def as_rv_discrete(self) -> tuple[rv_discrete, list[T0]]:
        """Return the distribution as a :py:class:`scipy.stats.rv_discrete` object.

        This method returns an RV over integers {0, 1, ..., k-1} where k is the size of
        the support, and a list whose i'th member is the item corresponding to the value
        i of the RV.
        """
        X = list(self._P.keys())
        return (rv_discrete(values=(range(len(X)), [self._P[x] for x in X])), X)

    @property
    def support(self) -> set[T0]:
        """Return the support of the distribution (set of all possible outcomes)."""
        return set(self._P.keys())

    def __eq__(self, other: object) -> bool:
        """Compare distributions for equality.

        Two distributions are equal when they have the same support and their
        probabilities agree up to the default tolerance of ``numpy.isclose``.
        """
        if not isinstance(other, ProbabilityDistribution):
            return NotImplemented
        keys0 = frozenset(self._P.keys())
        keys1 = frozenset(other._P.keys())
        if keys0 != keys1:
            return False
        return all(np.isclose(self._P[x], other._P[x]) for x in keys0)  # noqa: SLF001

    def __repr__(self) -> str:
        """Return a string of the form ``ClassName({...})``."""
        return f"{self.__class__.__name__}({self._P!r})"

    def __getitem__(self, x: T0) -> float:
        """Get the probability associated with a possible outcome.

        Outcomes outside the support have probability 0.
        """
        return self._P.get(x, 0.0)

    @classmethod
    def from_empirical_distribution(
        cls, ed: "EmpiricalDistribution[T0]"
    ) -> "ProbabilityDistribution[T0]":
        """Estimate a probability distribution from an empirical distribution.

        :param ed: Empirical distribution whose relative frequencies are used as
            probabilities.
        :raises ValueError: If the empirical distribution has no observations.
        """
        S = ed.total
        if S == 0:
            raise ValueError("Empirical distribution has no values")
        f = 1 / S
        return cls({x: f * c for x, c in ed.as_counter().items()})

    def condition(
        self, criterion: Callable[[T0], bool]
    ) -> "ProbabilityDistribution[T0]":
        """Return a new distribution conditioned on the given criterion.

        :param criterion: A boolean function defined on all possible outcomes.
        :raises ValueError: If the criterion holds with (numerically) zero
            probability.
        """
        S = sum(c for x, c in self._P.items() if criterion(x))
        if np.isclose(S, 0):
            raise ValueError("Condition has probability zero")
        # Rescale here so the constructor sees weights that already sum to 1 and
        # does not emit its renormalization warning.
        f = 1 / S
        return ProbabilityDistribution(
            {x: f * c for x, c in self._P.items() if criterion(x)}
        )

    def map(self, mapping: Callable[[T0], T1]) -> "ProbabilityDistribution[T1]":
        """Return a distribution over a transformed domain.

        The provided function maps elements in the original domain to new elements. If
        it is not injective, probabilities are combined.

        :param mapping: A function defined on all possible outcomes, mapping them to
            another domain.
        """
        P: defaultdict[Any, float] = defaultdict(float)
        for x, p in self._P.items():
            P[mapping(x)] += p
        return ProbabilityDistribution(P)

    def expectation(self, f: Callable[[T0], Number]) -> Number:
        """Compute the expectation value of a functional.

        The provided function maps possible outcomes to numerical values.

        :param f: Functional whose expectation is computed.
        :return: Expectation of the functional.
        """
        return sum(p * f(x) for x, p in self._P.items())

    def variance(self, f: Callable[[T0], Number]) -> Number:
        """Compute the variance of a functional.

        The provided function maps possible outcomes to numerical values.

        :param f: Functional whose variance is computed.
        :return: Variance of the functional.
        """
        # Evaluate the functional once per outcome, then use E[f^2] - E[f]^2.
        fs = [(f(x), p) for x, p in self._P.items()]
        return sum(p * v**2 for v, p in fs) - (sum(p * v for v, p in fs)) ** 2
def convex_combination(
    dists: list[tuple[ProbabilityDistribution[T0], float]],
) -> ProbabilityDistribution[T0]:
    """Mix several probability distributions using the given weights.

    ``dists`` is a list of ``(distribution, weight)`` pairs. Every weight has to be
    non-negative, and together the weights have to add up to 1.

    >>> dist1 = ProbabilityDistribution({0: 0.25, 1: 0.5, 2: 0.25})
    >>> dist2 = ProbabilityDistribution({0: 0.5, 1: 0.5})
    >>> dist3 = convex_combination([(dist1, 0.25), (dist2, 0.75)])
    >>> dist3
    ProbabilityDistribution({0: 0.4375, 1: 0.5, 2: 0.0625})
    >>> dist3.expectation(lambda x : x**2)
    0.75

    :param dists: Pairs of distribution and weight to combine.
    :return: The weighted mixture of the given distributions.
    :raises ValueError: If a weight is negative or the weights do not sum to 1.
    """
    mixture: defaultdict[T0, float] = defaultdict(float)
    weight_sum = 0.0
    for dist, weight in dists:
        if weight < 0:
            raise ValueError("Weights must be non-negative.")
        weight_sum += weight
        # Accumulate this distribution's contribution, scaled by its weight.
        for outcome, prob in dist._P.items():  # noqa: SLF001
            mixture[outcome] += weight * prob
    if np.isclose(weight_sum, 1):
        return ProbabilityDistribution(mixture)
    raise ValueError("Weights must sum to 1.")