diff --git a/Cargo.toml b/Cargo.toml
index 57795ce8..645c41c0 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -17,6 +17,7 @@ categories = ["data-structures", "science"]
[dependencies]
ndarray = "0.12.1"
noisy_float = "0.1.8"
+num-integer = "0.1"
num-traits = "0.2"
rand = "0.6"
itertools = { version = "0.7.0", default-features = false }
diff --git a/src/lib.rs b/src/lib.rs
index 9cf586f1..37499368 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -28,6 +28,7 @@
extern crate itertools;
extern crate ndarray;
extern crate noisy_float;
+extern crate num_integer;
extern crate num_traits;
extern crate rand;
diff --git a/src/summary_statistics/means.rs b/src/summary_statistics/means.rs
index a2b5bcd6..1609b531 100644
--- a/src/summary_statistics/means.rs
+++ b/src/summary_statistics/means.rs
@@ -1,6 +1,7 @@
use super::SummaryStatisticsExt;
use ndarray::{ArrayBase, Data, Dimension};
-use num_traits::{Float, FromPrimitive, Zero};
+use num_integer::IterBinomial;
+use num_traits::{Float, FromPrimitive, ToPrimitive, Zero};
use std::ops::{Add, Div};
impl SummaryStatisticsExt for ArrayBase
@@ -40,37 +41,37 @@ where
where
A: Float + FromPrimitive,
{
- let central_moments = self.central_moments(4);
- central_moments.map(|moments| moments[4] / moments[2].powi(2))
+ let central_moments = self.central_moments(4)?;
+ Some(central_moments[4] / central_moments[2].powi(2))
}
fn skewness(&self) -> Option
where
A: Float + FromPrimitive,
{
- let central_moments = self.central_moments(3);
- central_moments.map(|moments| moments[3] / moments[2].sqrt().powi(3))
+ let central_moments = self.central_moments(3)?;
+ Some(central_moments[3] / central_moments[2].sqrt().powi(3))
}
fn central_moment(&self, order: usize) -> Option
where
A: Float + FromPrimitive,
{
- let mean = self.mean();
- match mean {
- None => None,
- Some(mean) => match order {
- 0 => Some(A::one()),
- 1 => Some(A::zero()),
- n => {
- let shifted_array = self.map(|x| x.clone() - mean);
- let shifted_moments = moments(shifted_array, n);
- let correction_term = -shifted_moments[1].clone();
+ if self.is_empty() {
+ return None;
+ }
+ match order {
+ 0 => Some(A::one()),
+ 1 => Some(A::zero()),
+ n => {
+ let mean = self.mean().unwrap();
+ let shifted_array = self.mapv(|x| x - mean);
+ let shifted_moments = moments(shifted_array, n);
+ let correction_term = -shifted_moments[1];
- let coefficients = central_moment_coefficients(&shifted_moments);
- Some(horner_method(coefficients, correction_term))
- }
- },
+ let coefficients = central_moment_coefficients(&shifted_moments);
+ Some(horner_method(coefficients, correction_term))
+ }
}
}
@@ -78,29 +79,27 @@ where
where
A: Float + FromPrimitive,
{
- let mean = self.mean();
- match mean {
- None => None,
- Some(mean) => {
- match order {
- 0 => Some(vec![A::one()]),
- 1 => Some(vec![A::one(), A::zero()]),
- n => {
- // We only perform this operations once, and then reuse their
- // result to compute all the required moments
- let shifted_array = self.map(|x| x.clone() - mean);
- let shifted_moments = moments(shifted_array, n);
- let correction_term = -shifted_moments[1].clone();
+ if self.is_empty() {
+ return None;
+ }
+ match order {
+ 0 => Some(vec![A::one()]),
+ 1 => Some(vec![A::one(), A::zero()]),
+ n => {
+ // We only perform these operations once, and then reuse their
+ // result to compute all the required moments
+ let mean = self.mean().unwrap();
+ let shifted_array = self.mapv(|x| x - mean);
+ let shifted_moments = moments(shifted_array, n);
+ let correction_term = -shifted_moments[1];
- let mut central_moments = vec![A::one(), A::zero()];
- for k in 2..=n {
- let coefficients = central_moment_coefficients(&shifted_moments[..=k]);
- let central_moment = horner_method(coefficients, correction_term);
- central_moments.push(central_moment)
- }
- Some(central_moments)
- }
+ let mut central_moments = vec![A::one(), A::zero()];
+ for k in 2..=n {
+ let coefficients = central_moment_coefficients(&shifted_moments[..=k]);
+ let central_moment = horner_method(coefficients, correction_term);
+ central_moments.push(central_moment)
}
+ Some(central_moments)
}
}
}
@@ -126,6 +125,9 @@ where
{
let n_elements =
A::from_usize(a.len()).expect("Converting number of elements to `A` must not fail");
+ let order = order
+ .to_i32()
+ .expect("Moment order must not overflow `i32`.");
// When k=0, we are raising each element to the 0th power
// No need to waste CPU cycles going through the array
@@ -137,7 +139,7 @@ where
}
for k in 2..=order {
- moments.push(a.map(|x| x.powi(k as i32)).sum() / n_elements)
+ moments.push(a.map(|x| x.powi(k)).sum() / n_elements)
}
moments
}
@@ -152,34 +154,12 @@ where
A: Float + FromPrimitive,
{
let order = moments.len();
- moments
- .iter()
- .rev()
- .enumerate()
- .map(|(k, moment)| A::from_usize(binomial_coefficient(order, k)).unwrap() * *moment)
+ IterBinomial::new(order)
+ .zip(moments.iter().rev())
+ .map(|(binom, &moment)| A::from_usize(binom).unwrap() * moment)
.collect()
}
-/// Returns the binomial coefficient "n over k".
-///
-/// **Panics** if k > n.
-fn binomial_coefficient(n: usize, k: usize) -> usize {
- if k > n {
- panic!(
- "Tried to compute the binomial coefficient of {0} over {1}, \
- but {1} is strictly greater than {0}!"
- )
- }
- // BC(n, k) = BC(n, n-k)
- let k = if k > n - k { n - k } else { k };
- let mut result = 1;
- for i in 0..k {
- result = result * (n - i);
- result = result / (i + 1);
- }
- result
-}
-
/// Uses [Horner's method] to evaluate a polynomial with a single indeterminate.
///
/// Coefficients are expected to be sorted by ascending order
@@ -270,14 +250,16 @@ mod tests {
}
#[test]
- fn test_central_order_moment_with_empty_array_of_floats() {
+ fn test_central_moment_with_empty_array_of_floats() {
let a: Array1 = array![];
- assert!(a.central_moment(1).is_none());
- assert!(a.central_moments(1).is_none());
+ for order in 0..=3 {
+ assert!(a.central_moment(order).is_none());
+ assert!(a.central_moments(order).is_none());
+ }
}
#[test]
- fn test_zeroth_central_order_moment_is_one() {
+ fn test_zeroth_central_moment_is_one() {
let n = 50;
let bound: f64 = 200.;
let a = Array::random(n, Uniform::new(-bound.abs(), bound.abs()));
@@ -285,7 +267,7 @@ mod tests {
}
#[test]
- fn test_first_central_order_moment_is_zero() {
+ fn test_first_central_moment_is_zero() {
let n = 50;
let bound: f64 = 200.;
let a = Array::random(n, Uniform::new(-bound.abs(), bound.abs()));
@@ -293,7 +275,7 @@ mod tests {
}
#[test]
- fn test_central_order_moments() {
+ fn test_central_moments() {
let a: Array1 = array![
0.07820559, 0.5026185, 0.80935324, 0.39384033, 0.9483038, 0.62516215, 0.90772261,
0.87329831, 0.60267392, 0.2960298, 0.02810356, 0.31911966, 0.86705506, 0.96884832,
@@ -324,7 +306,7 @@ mod tests {
}
#[test]
- fn test_bulk_central_order_moments() {
+ fn test_bulk_central_moments() {
// Test that the bulk method is coherent with the non-bulk method
let n = 50;
let bound: f64 = 200.;
diff --git a/src/summary_statistics/mod.rs b/src/summary_statistics/mod.rs
index ff873d3b..a5664a0a 100644
--- a/src/summary_statistics/mod.rs
+++ b/src/summary_statistics/mod.rs
@@ -113,7 +113,8 @@ where
/// The *p*-th central moment is computed using a corrected two-pass algorithm (see Section 3.5
/// in [Pébay et al., 2016]). Complexity is *O(np)* when *n >> p*, *p > 1*.
///
- /// **Panics** if `A::from_usize()` fails to convert the number of elements in the array.
+ /// **Panics** if `A::from_usize()` fails to convert the number of elements
+ /// in the array or if `order` overflows `i32`.
///
/// [central moment]: https://en.wikipedia.org/wiki/Central_moment
/// [Pébay et al., 2016]: https://www.osti.gov/pages/servlets/purl/1427275
@@ -130,7 +131,8 @@ where
/// being thus more efficient than repeated calls to [central moment] if the computation
/// of central moments of multiple orders is required.
///
- /// **Panics** if `A::from_usize()` fails to convert the number of elements in the array.
+ /// **Panics** if `A::from_usize()` fails to convert the number of elements
+ /// in the array or if `order` overflows `i32`.
///
/// [central moments]: https://en.wikipedia.org/wiki/Central_moment
/// [central moment]: #tymethod.central_moment