Skip to content

Commit 43aac29

Browse files
authored
bpo-46257: Convert statistics._ss() to a single pass algorithm (GH-30403)
1 parent 46e4c25 commit 43aac29

File tree

2 files changed

+47
-57
lines changed

2 files changed

+47
-57
lines changed

Lib/statistics.py

Lines changed: 43 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@
138138
from bisect import bisect_left, bisect_right
139139
from math import hypot, sqrt, fabs, exp, erf, tau, log, fsum
140140
from operator import mul
141-
from collections import Counter, namedtuple
141+
from collections import Counter, namedtuple, defaultdict
142142

143143
_SQRT2 = sqrt(2.0)
144144

@@ -202,6 +202,43 @@ def _sum(data):
202202
return (T, total, count)
203203

204204

205+
def _ss(data, c=None):
206+
"""Return sum of square deviations of sequence data.
207+
208+
If ``c`` is None, the mean is calculated in one pass, and the deviations
209+
from the mean are calculated in a second pass. Otherwise, deviations are
210+
calculated from ``c`` as given. Use the second case with care, as it can
211+
lead to garbage results.
212+
"""
213+
if c is not None:
214+
T, total, count = _sum((d := x - c) * d for x in data)
215+
return (T, total, count)
216+
count = 0
217+
sx_partials = defaultdict(int)
218+
sxx_partials = defaultdict(int)
219+
T = int
220+
for typ, values in groupby(data, type):
221+
T = _coerce(T, typ) # or raise TypeError
222+
for n, d in map(_exact_ratio, values):
223+
count += 1
224+
sx_partials[d] += n
225+
sxx_partials[d] += n * n
226+
if not count:
227+
total = Fraction(0)
228+
elif None in sx_partials:
229+
# The sum will be a NAN or INF. We can ignore all the finite
230+
# partials, and just look at this special one.
231+
total = sx_partials[None]
232+
assert not _isfinite(total)
233+
else:
234+
sx = sum(Fraction(n, d) for d, n in sx_partials.items())
235+
sxx = sum(Fraction(n, d*d) for d, n in sxx_partials.items())
236+
# This formula has poor numeric properties for floats,
237+
# but with fractions it is exact.
238+
total = (count * sxx - sx * sx) / count
239+
return (T, total, count)
240+
241+
205242
def _isfinite(x):
206243
try:
207244
return x.is_finite() # Likely a Decimal.
@@ -399,13 +436,9 @@ def mean(data):
399436
400437
If ``data`` is empty, StatisticsError will be raised.
401438
"""
402-
if iter(data) is data:
403-
data = list(data)
404-
n = len(data)
439+
T, total, n = _sum(data)
405440
if n < 1:
406441
raise StatisticsError('mean requires at least one data point')
407-
T, total, count = _sum(data)
408-
assert count == n
409442
return _convert(total / n, T)
410443

411444

@@ -776,41 +809,6 @@ def quantiles(data, *, n=4, method='exclusive'):
776809

777810
# See http://mathworld.wolfram.com/Variance.html
778811
# http://mathworld.wolfram.com/SampleVariance.html
779-
# http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
780-
#
781-
# Under no circumstances use the so-called "computational formula for
782-
# variance", as that is only suitable for hand calculations with a small
783-
# amount of low-precision data. It has terrible numeric properties.
784-
#
785-
# See a comparison of three computational methods here:
786-
# http://www.johndcook.com/blog/2008/09/26/comparing-three-methods-of-computing-standard-deviation/
787-
788-
def _ss(data, c=None):
789-
"""Return sum of square deviations of sequence data.
790-
791-
If ``c`` is None, the mean is calculated in one pass, and the deviations
792-
from the mean are calculated in a second pass. Otherwise, deviations are
793-
calculated from ``c`` as given. Use the second case with care, as it can
794-
lead to garbage results.
795-
"""
796-
if c is not None:
797-
T, total, count = _sum((d := x - c) * d for x in data)
798-
return (T, total)
799-
T, total, count = _sum(data)
800-
mean_n, mean_d = (total / count).as_integer_ratio()
801-
partials = Counter()
802-
for n, d in map(_exact_ratio, data):
803-
diff_n = n * mean_d - d * mean_n
804-
diff_d = d * mean_d
805-
partials[diff_d * diff_d] += diff_n * diff_n
806-
if None in partials:
807-
# The sum will be a NAN or INF. We can ignore all the finite
808-
# partials, and just look at this special one.
809-
total = partials[None]
810-
assert not _isfinite(total)
811-
else:
812-
total = sum(Fraction(n, d) for d, n in partials.items())
813-
return (T, total)
814812

815813

816814
def variance(data, xbar=None):
@@ -851,12 +849,9 @@ def variance(data, xbar=None):
851849
Fraction(67, 108)
852850
853851
"""
854-
if iter(data) is data:
855-
data = list(data)
856-
n = len(data)
852+
T, ss, n = _ss(data, xbar)
857853
if n < 2:
858854
raise StatisticsError('variance requires at least two data points')
859-
T, ss = _ss(data, xbar)
860855
return _convert(ss / (n - 1), T)
861856

862857

@@ -895,12 +890,9 @@ def pvariance(data, mu=None):
895890
Fraction(13, 72)
896891
897892
"""
898-
if iter(data) is data:
899-
data = list(data)
900-
n = len(data)
893+
T, ss, n = _ss(data, mu)
901894
if n < 1:
902895
raise StatisticsError('pvariance requires at least one data point')
903-
T, ss = _ss(data, mu)
904896
return _convert(ss / n, T)
905897

906898

@@ -913,12 +905,9 @@ def stdev(data, xbar=None):
913905
1.0810874155219827
914906
915907
"""
916-
if iter(data) is data:
917-
data = list(data)
918-
n = len(data)
908+
T, ss, n = _ss(data, xbar)
919909
if n < 2:
920910
raise StatisticsError('stdev requires at least two data points')
921-
T, ss = _ss(data, xbar)
922911
mss = ss / (n - 1)
923912
if issubclass(T, Decimal):
924913
return _decimal_sqrt_of_frac(mss.numerator, mss.denominator)
@@ -934,12 +923,9 @@ def pstdev(data, mu=None):
934923
0.986893273527251
935924
936925
"""
937-
if iter(data) is data:
938-
data = list(data)
939-
n = len(data)
926+
T, ss, n = _ss(data, mu)
940927
if n < 1:
941928
raise StatisticsError('pstdev requires at least one data point')
942-
T, ss = _ss(data, mu)
943929
mss = ss / n
944930
if issubclass(T, Decimal):
945931
return _decimal_sqrt_of_frac(mss.numerator, mss.denominator)
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
Optimized the mean, variance, and stdev functions in the statistics module.
2+
If the input is an iterator, it is consumed in a single pass rather than
3+
eating memory by conversion to a list. The single pass algorithm is about
4+
twice as fast as the previous two pass code.

0 commit comments

Comments
 (0)