Coverage for /home/runner/work/tket/tket/pytket/pytket/utils/distribution.py: 89%
114 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-09 15:08 +0000
1# Copyright Quantinuum
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
15import warnings
16from collections import Counter, defaultdict
17from collections.abc import Callable
18from typing import Any, Generic, TypeVar, Union
20import numpy as np
21from scipy.stats import rv_discrete
# Numeric result type of functionals evaluated over a distribution.
Number = Union[float, complex]  # noqa: UP007
# Type variables for the outcome domain before (T0) and after (T1) a mapping.
T0, T1 = TypeVar("T0"), TypeVar("T1")
class EmpiricalDistribution(Generic[T0]):
    """Represents an empirical distribution of values.

    Supports methods for combination, marginalization, expectation value, etc.

    >>> dist1 = EmpiricalDistribution(Counter({(0, 0): 3, (0, 1): 2, (1, 0): 4, (1, 1):
    ... 0}))
    >>> dist2 = EmpiricalDistribution(Counter({(0, 0): 1, (0, 1): 0, (1, 0): 2, (1, 1):
    ... 1}))
    >>> dist1.sample_mean(lambda x : x[0] + 2*x[1])
    0.8888888888888888
    >>> dist3 = dist2.condition(lambda x: x[0] == 1)
    >>> dist3
    EmpiricalDistribution(Counter({(1, 0): 2, (1, 1): 1}))
    >>> dist4 = dist1 + dist3
    >>> dist4
    EmpiricalDistribution(Counter({(1, 0): 6, (0, 0): 3, (0, 1): 2, (1, 1): 1}))
    """

    def __init__(self, C: Counter[T0]):
        """Initialize from a counter of observations.

        :param C: Mapping from observed values to their counts. Entries whose count
            is not strictly positive are discarded.
        """
        # Only positive counts are stored, so support, equality and repr are not
        # affected by zero (or negative) entries in the input counter.
        self._C: Counter[T0] = Counter({x: c for x, c in C.items() if c > 0})

    def as_counter(self) -> Counter[T0]:
        """Return the distribution as a :py:class:`collections.Counter` object.

        A copy is returned, so mutating the result does not alter this distribution
        (and cannot break the invariant that only positive counts are stored).
        """
        return Counter(self._C)

    @property
    def total(self) -> int:
        """Return the total number of observations."""
        return self._C.total()

    @property
    def support(self) -> set[T0]:
        """Return the support of the distribution (set of all observations)."""
        return set(self._C.keys())

    def __eq__(self, other: object) -> bool:
        """Compare distributions for equality (exact equality of counts)."""
        if not isinstance(other, EmpiricalDistribution):
            return NotImplemented
        return self._C == other._C

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self._C!r})"

    def __getitem__(self, x: T0) -> int:
        """Get the count associated with an observation (0 if never observed)."""
        return self._C[x]

    def __add__(
        self, other: "EmpiricalDistribution[T0]"
    ) -> "EmpiricalDistribution[T0]":
        """Combine two distributions by adding their counts.

        Returns ``NotImplemented`` for operands that are not empirical
        distributions, so Python raises a ``TypeError`` rather than an
        ``AttributeError``.
        """
        if not isinstance(other, EmpiricalDistribution):
            return NotImplemented
        return EmpiricalDistribution(self._C + other._C)

    def condition(self, criterion: Callable[[T0], bool]) -> "EmpiricalDistribution[T0]":
        """Return a new distribution conditioned on the given criterion.

        Only observations for which ``criterion`` returns a truthy value are kept;
        their counts are unchanged.

        :param criterion: A boolean function defined on all possible observations.
        """
        return EmpiricalDistribution(
            Counter({x: c for x, c in self._C.items() if criterion(x)})
        )

    def map(self, mapping: Callable[[T0], T1]) -> "EmpiricalDistribution[T1]":
        """Return a distribution over a transformed domain.

        The provided function maps elements in the original domain to new elements. If
        it is not injective, counts are combined.

        :param mapping: A function defined on all possible observations, mapping them
            to another domain.
        """
        C: Counter[T1] = Counter()
        for x, c in self._C.items():
            C[mapping(x)] += c
        return EmpiricalDistribution(C)

    def sample_mean(self, f: Callable[[T0], Number]) -> Number:
        """Compute the sample mean of a functional.

        The provided function maps observations to numerical values.

        :raises ZeroDivisionError: if the distribution has no observations.
        :return: Estimate of the mean of the functional based on the observations."""
        return sum(c * f(x) for x, c in self._C.items()) / self.total

    def sample_variance(self, f: Callable[[T0], Number]) -> Number:
        """Compute the sample variance of a functional.

        The provided function maps observations to numerical values.

        The sample variance is an unbiased estimate of the variance of the underlying
        distribution.

        :raises RuntimeError: if fewer than two observations are available.
        :return: Estimate of the variance of the functional based on the
            observations."""
        n = self.total
        if n < 2:  # noqa: PLR2004
            raise RuntimeError(
                "At least two samples are required in order to compute the sample "
                "variance."
            )
        fs = [(f(x), c) for x, c in self._C.items()]
        # Two-pass formula: subtract the mean first, then sum squared deviations.
        # Algebraically equal to (sum(c*v**2) - sum(c*v)**2 / n) / (n - 1), but it
        # avoids the catastrophic cancellation that one-pass form suffers when the
        # mean is large relative to the spread.
        mean = sum(c * v for v, c in fs) / n
        return sum(c * (v - mean) ** 2 for v, c in fs) / (n - 1)
class ProbabilityDistribution(Generic[T0]):
    """Represents an exact probability distribution.

    Supports methods for combination, marginalization, expectation value, etc. May be
    derived from an :py:class:`EmpiricalDistribution`.
    """

    def __init__(self, P: dict[T0, float], min_p: float = 0.0):
        """Initialize with a dictionary of probabilities.

        :param P: Dictionary of probabilities.
        :param min_p: Optional probability below which to ignore values. Default
            0. Distribution is renormalized after removing these values.

        The values must be non-negative. If they do not sum to 1, a warning is
        emitted; the distribution will contain normalized values.

        Values less than or equal to ``min_p`` are discarded. The comparison is made
        against the values of ``P`` as given, i.e. before renormalization.

        :raises ValueError: if any value is negative, if the values sum to
            (approximately) zero, or if no value exceeds ``min_p``.
        """
        if any(x < 0 for x in P.values()):
            raise ValueError("Distribution contains negative probabilities")
        s0 = sum(P.values())
        if np.isclose(s0, 0):
            raise ValueError("Distribution has zero weight")
        if not np.isclose(s0, 1):
            warnings.warn(
                "Probabilities used to initialize ProbabilityDistribution do "
                "not sum to 1: renormalizing.",
                stacklevel=2,
            )
        newP = {x: p for x, p in P.items() if p > min_p}
        if not newP:
            # Without this check an empty (invalid) distribution would be built
            # silently when every value is filtered out by min_p.
            raise ValueError("All probabilities are less than or equal to min_p")
        s = sum(newP.values())
        self._P = {x: p / s for x, p in newP.items()}

    def as_dict(self) -> dict[T0, float]:
        """Return the distribution as a :py:class:`dict` object.

        A copy is returned, so mutating the result does not alter this distribution.
        """
        return dict(self._P)

    def as_rv_discrete(self) -> tuple[rv_discrete, list[T0]]:
        """Return the distribution as a :py:class:`scipy.stats.rv_discrete` object.

        This method returns an RV over integers {0, 1, ..., k-1} where k is the size of
        the support, and a list whose i'th member is the item corresponding to the value
        i of the RV.
        """
        X = list(self._P.keys())
        return (rv_discrete(values=(range(len(X)), [self._P[x] for x in X])), X)

    @property
    def support(self) -> set[T0]:
        """Return the support of the distribution (set of all possible outcomes)."""
        return set(self._P.keys())

    def __eq__(self, other: object) -> bool:
        """Compare distributions for equality.

        Supports must match exactly; probabilities are compared with
        :py:func:`numpy.isclose`.
        """
        if not isinstance(other, ProbabilityDistribution):
            return NotImplemented
        keys0 = frozenset(self._P.keys())
        keys1 = frozenset(other._P.keys())
        if keys0 != keys1:
            return False
        return all(np.isclose(self._P[x], other._P[x]) for x in keys0)  # noqa: SLF001

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self._P!r})"

    def __getitem__(self, x: T0) -> float:
        """Get the probability associated with a possible outcome (0.0 if absent)."""
        return self._P.get(x, 0.0)

    @classmethod
    def from_empirical_distribution(
        cls, ed: EmpiricalDistribution[T0]
    ) -> "ProbabilityDistribution[T0]":
        """Estimate a probability distribution from an empirical distribution.

        Each outcome's probability is its relative frequency.

        :raises ValueError: if the empirical distribution has no observations.
        """
        S = ed.total
        if S == 0:
            raise ValueError("Empirical distribution has no values")
        f = 1 / S
        return cls({x: f * c for x, c in ed.as_counter().items()})

    def condition(
        self, criterion: Callable[[T0], bool]
    ) -> "ProbabilityDistribution[T0]":
        """Return a new distribution conditioned on the given criterion.

        :param criterion: A boolean function defined on all possible outcomes.
        :raises ValueError: if the criterion holds with (approximately) zero
            probability.
        """
        S = sum(c for x, c in self._P.items() if criterion(x))
        if np.isclose(S, 0):
            raise ValueError("Condition has probability zero")
        f = 1 / S
        return ProbabilityDistribution(
            {x: f * c for x, c in self._P.items() if criterion(x)}
        )

    def map(self, mapping: Callable[[T0], T1]) -> "ProbabilityDistribution[T1]":
        """Return a distribution over a transformed domain.

        The provided function maps elements in the original domain to new elements. If
        it is not injective, probabilities are combined.

        :param mapping: A function defined on all possible outcomes, mapping them to
            another domain.
        """
        P: defaultdict[Any, float] = defaultdict(float)
        for x, p in self._P.items():
            P[mapping(x)] += p
        return ProbabilityDistribution(P)

    def expectation(self, f: Callable[[T0], Number]) -> Number:
        """Compute the expectation value of a functional.

        The provided function maps possible outcomes to numerical values.

        :return: Expectation of the functional.
        """
        return sum(p * f(x) for x, p in self._P.items())

    def variance(self, f: Callable[[T0], Number]) -> Number:
        """Compute the variance of a functional.

        The provided function maps possible outcomes to numerical values.

        :return: Variance of the functional, computed as ``E[(f - E[f])**2]``
            (no complex conjugation is applied for complex-valued functionals).
        """
        fs = [(f(x), p) for x, p in self._P.items()]
        # Two-pass formula: since the stored probabilities sum to 1 this equals
        # E[f**2] - E[f]**2, but it avoids the catastrophic cancellation that
        # difference suffers when the mean is large relative to the spread.
        mean = sum(p * v for v, p in fs)
        return sum(p * (v - mean) ** 2 for v, p in fs)
def convex_combination(
    dists: list[tuple[ProbabilityDistribution[T0], float]],
) -> ProbabilityDistribution[T0]:
    """Return a convex combination of probability distributions.

    Each pair in the list comprises a distribution and a weight. The weights must be
    non-negative and sum to 1.

    >>> dist1 = ProbabilityDistribution({0: 0.25, 1: 0.5, 2: 0.25})
    >>> dist2 = ProbabilityDistribution({0: 0.5, 1: 0.5})
    >>> dist3 = convex_combination([(dist1, 0.25), (dist2, 0.75)])
    >>> dist3
    ProbabilityDistribution({0: 0.4375, 1: 0.5, 2: 0.0625})
    >>> dist3.expectation(lambda x : x**2)
    0.75

    :raises ValueError: if any weight is negative, or if the weights do not sum
        to (approximately) 1.
    """
    mixture: defaultdict[T0, float] = defaultdict(float)
    total_weight = 0.0
    for dist, weight in dists:
        if weight < 0:
            raise ValueError("Weights must be non-negative.")
        # Accumulate weight * probability for every outcome in this component.
        for outcome, prob in dist.as_dict().items():
            mixture[outcome] += weight * prob
        total_weight += weight
    if not np.isclose(total_weight, 1):
        raise ValueError("Weights must sum to 1.")
    return ProbabilityDistribution(mixture)