Skip to content

Commit 94f0950

Browse files
Merge pull request #3 from jturner314/central-moments
Miscellaneous small improvements to central moments
2 parents a701f33 + a8af880 commit 94f0950

File tree

4 files changed

+61
-75
lines changed

4 files changed

+61
-75
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ categories = ["data-structures", "science"]
1717
[dependencies]
1818
ndarray = "0.12.1"
1919
noisy_float = "0.1.8"
20+
num-integer = "0.1"
2021
num-traits = "0.2"
2122
rand = "0.6"
2223
itertools = { version = "0.7.0", default-features = false }

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
extern crate itertools;
2929
extern crate ndarray;
3030
extern crate noisy_float;
31+
extern crate num_integer;
3132
extern crate num_traits;
3233
extern crate rand;
3334

src/summary_statistics/means.rs

Lines changed: 55 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use super::SummaryStatisticsExt;
22
use ndarray::{ArrayBase, Data, Dimension};
3-
use num_traits::{Float, FromPrimitive, Zero};
3+
use num_integer::IterBinomial;
4+
use num_traits::{Float, FromPrimitive, ToPrimitive, Zero};
45
use std::ops::{Add, Div};
56

67
impl<A, S, D> SummaryStatisticsExt<A, S, D> for ArrayBase<S, D>
@@ -40,67 +41,65 @@ where
4041
where
4142
A: Float + FromPrimitive,
4243
{
43-
let central_moments = self.central_moments(4);
44-
central_moments.map(|moments| moments[4] / moments[2].powi(2))
44+
let central_moments = self.central_moments(4)?;
45+
Some(central_moments[4] / central_moments[2].powi(2))
4546
}
4647

4748
fn skewness(&self) -> Option<A>
4849
where
4950
A: Float + FromPrimitive,
5051
{
51-
let central_moments = self.central_moments(3);
52-
central_moments.map(|moments| moments[3] / moments[2].sqrt().powi(3))
52+
let central_moments = self.central_moments(3)?;
53+
Some(central_moments[3] / central_moments[2].sqrt().powi(3))
5354
}
5455

5556
fn central_moment(&self, order: usize) -> Option<A>
5657
where
5758
A: Float + FromPrimitive,
5859
{
59-
let mean = self.mean();
60-
match mean {
61-
None => None,
62-
Some(mean) => match order {
63-
0 => Some(A::one()),
64-
1 => Some(A::zero()),
65-
n => {
66-
let shifted_array = self.map(|x| x.clone() - mean);
67-
let shifted_moments = moments(shifted_array, n);
68-
let correction_term = -shifted_moments[1].clone();
60+
if self.is_empty() {
61+
return None;
62+
}
63+
match order {
64+
0 => Some(A::one()),
65+
1 => Some(A::zero()),
66+
n => {
67+
let mean = self.mean().unwrap();
68+
let shifted_array = self.mapv(|x| x - mean);
69+
let shifted_moments = moments(shifted_array, n);
70+
let correction_term = -shifted_moments[1];
6971

70-
let coefficients = central_moment_coefficients(&shifted_moments);
71-
Some(horner_method(coefficients, correction_term))
72-
}
73-
},
72+
let coefficients = central_moment_coefficients(&shifted_moments);
73+
Some(horner_method(coefficients, correction_term))
74+
}
7475
}
7576
}
7677

7778
fn central_moments(&self, order: usize) -> Option<Vec<A>>
7879
where
7980
A: Float + FromPrimitive,
8081
{
81-
let mean = self.mean();
82-
match mean {
83-
None => None,
84-
Some(mean) => {
85-
match order {
86-
0 => Some(vec![A::one()]),
87-
1 => Some(vec![A::one(), A::zero()]),
88-
n => {
89-
// We only perform this operations once, and then reuse their
90-
// result to compute all the required moments
91-
let shifted_array = self.map(|x| x.clone() - mean);
92-
let shifted_moments = moments(shifted_array, n);
93-
let correction_term = -shifted_moments[1].clone();
82+
if self.is_empty() {
83+
return None;
84+
}
85+
match order {
86+
0 => Some(vec![A::one()]),
87+
1 => Some(vec![A::one(), A::zero()]),
88+
n => {
89+
// We only perform these operations once, and then reuse their
90+
// result to compute all the required moments
91+
let mean = self.mean().unwrap();
92+
let shifted_array = self.mapv(|x| x - mean);
93+
let shifted_moments = moments(shifted_array, n);
94+
let correction_term = -shifted_moments[1];
9495

95-
let mut central_moments = vec![A::one(), A::zero()];
96-
for k in 2..=n {
97-
let coefficients = central_moment_coefficients(&shifted_moments[..=k]);
98-
let central_moment = horner_method(coefficients, correction_term);
99-
central_moments.push(central_moment)
100-
}
101-
Some(central_moments)
102-
}
96+
let mut central_moments = vec![A::one(), A::zero()];
97+
for k in 2..=n {
98+
let coefficients = central_moment_coefficients(&shifted_moments[..=k]);
99+
let central_moment = horner_method(coefficients, correction_term);
100+
central_moments.push(central_moment)
103101
}
102+
Some(central_moments)
104103
}
105104
}
106105
}
@@ -126,6 +125,9 @@ where
126125
{
127126
let n_elements =
128127
A::from_usize(a.len()).expect("Converting number of elements to `A` must not fail");
128+
let order = order
129+
.to_i32()
130+
.expect("Moment order must not overflow `i32`.");
129131

130132
// When k=0, we are raising each element to the 0th power
131133
// No need to waste CPU cycles going through the array
@@ -137,7 +139,7 @@ where
137139
}
138140

139141
for k in 2..=order {
140-
moments.push(a.map(|x| x.powi(k as i32)).sum() / n_elements)
142+
moments.push(a.map(|x| x.powi(k)).sum() / n_elements)
141143
}
142144
moments
143145
}
@@ -152,34 +154,12 @@ where
152154
A: Float + FromPrimitive,
153155
{
154156
let order = moments.len();
155-
moments
156-
.iter()
157-
.rev()
158-
.enumerate()
159-
.map(|(k, moment)| A::from_usize(binomial_coefficient(order, k)).unwrap() * *moment)
157+
IterBinomial::new(order)
158+
.zip(moments.iter().rev())
159+
.map(|(binom, &moment)| A::from_usize(binom).unwrap() * moment)
160160
.collect()
161161
}
162162

163-
/// Returns the binomial coefficient "n over k".
164-
///
165-
/// **Panics** if k > n.
166-
fn binomial_coefficient(n: usize, k: usize) -> usize {
167-
if k > n {
168-
panic!(
169-
"Tried to compute the binomial coefficient of {0} over {1}, \
170-
but {1} is strictly greater than {0}!"
171-
)
172-
}
173-
// BC(n, k) = BC(n, n-k)
174-
let k = if k > n - k { n - k } else { k };
175-
let mut result = 1;
176-
for i in 0..k {
177-
result = result * (n - i);
178-
result = result / (i + 1);
179-
}
180-
result
181-
}
182-
183163
/// Uses [Horner's method] to evaluate a polynomial with a single indeterminate.
184164
///
185165
/// Coefficients are expected to be sorted by ascending order
@@ -270,30 +250,32 @@ mod tests {
270250
}
271251

272252
#[test]
273-
fn test_central_order_moment_with_empty_array_of_floats() {
253+
fn test_central_moment_with_empty_array_of_floats() {
274254
let a: Array1<f64> = array![];
275-
assert!(a.central_moment(1).is_none());
276-
assert!(a.central_moments(1).is_none());
255+
for order in 0..=3 {
256+
assert!(a.central_moment(order).is_none());
257+
assert!(a.central_moments(order).is_none());
258+
}
277259
}
278260

279261
#[test]
280-
fn test_zeroth_central_order_moment_is_one() {
262+
fn test_zeroth_central_moment_is_one() {
281263
let n = 50;
282264
let bound: f64 = 200.;
283265
let a = Array::random(n, Uniform::new(-bound.abs(), bound.abs()));
284266
assert_eq!(a.central_moment(0).unwrap(), 1.);
285267
}
286268

287269
#[test]
288-
fn test_first_central_order_moment_is_zero() {
270+
fn test_first_central_moment_is_zero() {
289271
let n = 50;
290272
let bound: f64 = 200.;
291273
let a = Array::random(n, Uniform::new(-bound.abs(), bound.abs()));
292274
assert_eq!(a.central_moment(1).unwrap(), 0.);
293275
}
294276

295277
#[test]
296-
fn test_central_order_moments() {
278+
fn test_central_moments() {
297279
let a: Array1<f64> = array![
298280
0.07820559, 0.5026185, 0.80935324, 0.39384033, 0.9483038, 0.62516215, 0.90772261,
299281
0.87329831, 0.60267392, 0.2960298, 0.02810356, 0.31911966, 0.86705506, 0.96884832,
@@ -324,7 +306,7 @@ mod tests {
324306
}
325307

326308
#[test]
327-
fn test_bulk_central_order_moments() {
309+
fn test_bulk_central_moments() {
328310
// Test that the bulk method is coherent with the non-bulk method
329311
let n = 50;
330312
let bound: f64 = 200.;

src/summary_statistics/mod.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,8 @@ where
113113
/// The *p*-th central moment is computed using a corrected two-pass algorithm (see Section 3.5
114114
/// in [Pébay et al., 2016]). Complexity is *O(np)* when *n >> p*, *p > 1*.
115115
///
116-
/// **Panics** if `A::from_usize()` fails to convert the number of elements in the array.
116+
/// **Panics** if `A::from_usize()` fails to convert the number of elements
117+
/// in the array or if `order` overflows `i32`.
117118
///
118119
/// [central moment]: https://en.wikipedia.org/wiki/Central_moment
119120
/// [Pébay et al., 2016]: https://www.osti.gov/pages/servlets/purl/1427275
@@ -130,7 +131,8 @@ where
130131
/// being thus more efficient than repeated calls to [central moment] if the computation
131132
/// of central moments of multiple orders is required.
132133
///
133-
/// **Panics** if `A::from_usize()` fails to convert the number of elements in the array.
134+
/// **Panics** if `A::from_usize()` fails to convert the number of elements
135+
/// in the array or if `order` overflows `i32`.
134136
///
135137
/// [central moments]: https://en.wikipedia.org/wiki/Central_moment
136138
/// [central moment]: #tymethod.central_moment

0 commit comments

Comments
 (0)