diff --git a/Cargo.toml b/Cargo.toml index 57795ce8..645c41c0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,7 @@ categories = ["data-structures", "science"] [dependencies] ndarray = "0.12.1" noisy_float = "0.1.8" +num-integer = "0.1" num-traits = "0.2" rand = "0.6" itertools = { version = "0.7.0", default-features = false } diff --git a/src/lib.rs b/src/lib.rs index 9cf586f1..37499368 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -28,6 +28,7 @@ extern crate itertools; extern crate ndarray; extern crate noisy_float; +extern crate num_integer; extern crate num_traits; extern crate rand; diff --git a/src/summary_statistics/means.rs b/src/summary_statistics/means.rs index a2b5bcd6..1609b531 100644 --- a/src/summary_statistics/means.rs +++ b/src/summary_statistics/means.rs @@ -1,6 +1,7 @@ use super::SummaryStatisticsExt; use ndarray::{ArrayBase, Data, Dimension}; -use num_traits::{Float, FromPrimitive, Zero}; +use num_integer::IterBinomial; +use num_traits::{Float, FromPrimitive, ToPrimitive, Zero}; use std::ops::{Add, Div}; impl SummaryStatisticsExt for ArrayBase @@ -40,37 +41,37 @@ where where A: Float + FromPrimitive, { - let central_moments = self.central_moments(4); - central_moments.map(|moments| moments[4] / moments[2].powi(2)) + let central_moments = self.central_moments(4)?; + Some(central_moments[4] / central_moments[2].powi(2)) } fn skewness(&self) -> Option where A: Float + FromPrimitive, { - let central_moments = self.central_moments(3); - central_moments.map(|moments| moments[3] / moments[2].sqrt().powi(3)) + let central_moments = self.central_moments(3)?; + Some(central_moments[3] / central_moments[2].sqrt().powi(3)) } fn central_moment(&self, order: usize) -> Option where A: Float + FromPrimitive, { - let mean = self.mean(); - match mean { - None => None, - Some(mean) => match order { - 0 => Some(A::one()), - 1 => Some(A::zero()), - n => { - let shifted_array = self.map(|x| x.clone() - mean); - let shifted_moments = moments(shifted_array, n); - let correction_term = -shifted_moments[1].clone(); + if self.is_empty() { + return None; + } + match order { + 0 => Some(A::one()), + 1 => Some(A::zero()), + n => { + let mean = self.mean().unwrap(); + let shifted_array = self.mapv(|x| x - mean); + let shifted_moments = moments(shifted_array, n); + let correction_term = -shifted_moments[1]; - let coefficients = central_moment_coefficients(&shifted_moments); - Some(horner_method(coefficients, correction_term)) - } - }, + let coefficients = central_moment_coefficients(&shifted_moments); + Some(horner_method(coefficients, correction_term)) + } } } @@ -78,29 +79,27 @@ where where A: Float + FromPrimitive, { - let mean = self.mean(); - match mean { - None => None, - Some(mean) => { - match order { - 0 => Some(vec![A::one()]), - 1 => Some(vec![A::one(), A::zero()]), - n => { - // We only perform this operations once, and then reuse their - // result to compute all the required moments - let shifted_array = self.map(|x| x.clone() - mean); - let shifted_moments = moments(shifted_array, n); - let correction_term = -shifted_moments[1].clone(); + if self.is_empty() { + return None; + } + match order { + 0 => Some(vec![A::one()]), + 1 => Some(vec![A::one(), A::zero()]), + n => { + // We only perform these operations once, and then reuse their + // result to compute all the required moments + let mean = self.mean().unwrap(); + let shifted_array = self.mapv(|x| x - mean); + let shifted_moments = moments(shifted_array, n); + let correction_term = -shifted_moments[1]; - let mut central_moments = vec![A::one(), A::zero()]; - for k in 2..=n { - let coefficients = central_moment_coefficients(&shifted_moments[..=k]); - let central_moment = horner_method(coefficients, correction_term); - central_moments.push(central_moment) - } - Some(central_moments) - } + let mut central_moments = vec![A::one(), A::zero()]; + for k in 2..=n { + let coefficients = central_moment_coefficients(&shifted_moments[..=k]); + let central_moment = horner_method(coefficients, correction_term); + central_moments.push(central_moment) } + Some(central_moments) } } } @@ -126,6 +125,9 @@ where { let n_elements = A::from_usize(a.len()).expect("Converting number of elements to `A` must not fail"); + let order = order + .to_i32() + .expect("Moment order must not overflow `i32`."); // When k=0, we are raising each element to the 0th power // No need to waste CPU cycles going through the array @@ -137,7 +139,7 @@ where } for k in 2..=order { - moments.push(a.map(|x| x.powi(k as i32)).sum() / n_elements) + moments.push(a.map(|x| x.powi(k)).sum() / n_elements) } moments } @@ -152,34 +154,12 @@ where A: Float + FromPrimitive, { let order = moments.len(); - moments - .iter() - .rev() - .enumerate() - .map(|(k, moment)| A::from_usize(binomial_coefficient(order, k)).unwrap() * *moment) + IterBinomial::new(order) + .zip(moments.iter().rev()) + .map(|(binom, &moment)| A::from_usize(binom).unwrap() * moment) .collect() } -/// Returns the binomial coefficient "n over k". -/// -/// **Panics** if k > n. -fn binomial_coefficient(n: usize, k: usize) -> usize { - if k > n { - panic!( - "Tried to compute the binomial coefficient of {0} over {1}, \ - but {1} is strictly greater than {0}!" - ) - } - // BC(n, k) = BC(n, n-k) - let k = if k > n - k { n - k } else { k }; - let mut result = 1; - for i in 0..k { - result = result * (n - i); - result = result / (i + 1); - } - result -} - /// Uses [Horner's method] to evaluate a polynomial with a single indeterminate. /// /// Coefficients are expected to be sorted by ascending order @@ -270,14 +250,16 @@ mod tests { } #[test] - fn test_central_order_moment_with_empty_array_of_floats() { + fn test_central_moment_with_empty_array_of_floats() { let a: Array1 = array![]; - assert!(a.central_moment(1).is_none()); - assert!(a.central_moments(1).is_none()); + for order in 0..=3 { + assert!(a.central_moment(order).is_none()); + assert!(a.central_moments(order).is_none()); + } } #[test] - fn test_zeroth_central_order_moment_is_one() { + fn test_zeroth_central_moment_is_one() { let n = 50; let bound: f64 = 200.; let a = Array::random(n, Uniform::new(-bound.abs(), bound.abs())); @@ -285,7 +267,7 @@ mod tests { } #[test] - fn test_first_central_order_moment_is_zero() { + fn test_first_central_moment_is_zero() { let n = 50; let bound: f64 = 200.; let a = Array::random(n, Uniform::new(-bound.abs(), bound.abs())); @@ -293,7 +275,7 @@ mod tests { } #[test] - fn test_central_order_moments() { + fn test_central_moments() { let a: Array1 = array![ 0.07820559, 0.5026185, 0.80935324, 0.39384033, 0.9483038, 0.62516215, 0.90772261, 0.87329831, 0.60267392, 0.2960298, 0.02810356, 0.31911966, 0.86705506, 0.96884832, @@ -324,7 +306,7 @@ mod tests { } #[test] - fn test_bulk_central_order_moments() { + fn test_bulk_central_moments() { // Test that the bulk method is coherent with the non-bulk method let n = 50; let bound: f64 = 200.; diff --git a/src/summary_statistics/mod.rs b/src/summary_statistics/mod.rs index ff873d3b..a5664a0a 100644 --- a/src/summary_statistics/mod.rs +++ b/src/summary_statistics/mod.rs @@ -113,7 +113,8 @@ where /// The *p*-th central moment is computed using a corrected two-pass algorithm (see Section 3.5 /// in [Pébay et al., 2016]). Complexity is *O(np)* when *n >> p*, *p > 1*. /// - /// **Panics** if `A::from_usize()` fails to convert the number of elements in the array. + /// **Panics** if `A::from_usize()` fails to convert the number of elements + /// in the array or if `order` overflows `i32`. /// /// [central moment]: https://en.wikipedia.org/wiki/Central_moment /// [Pébay et al., 2016]: https://www.osti.gov/pages/servlets/purl/1427275 @@ -130,7 +131,8 @@ where /// being thus more efficient than repeated calls to [central moment] if the computation /// of central moments of multiple orders is required. /// - /// **Panics** if `A::from_usize()` fails to convert the number of elements in the array. + /// **Panics** if `A::from_usize()` fails to convert the number of elements + /// in the array or if `order` overflows `i32`. /// /// [central moments]: https://en.wikipedia.org/wiki/Central_moment /// [central moment]: #tymethod.central_moment