diff --git a/src/numeric/impl_numeric.rs b/src/numeric/impl_numeric.rs index 14485fc79..9e229b46a 100644 --- a/src/numeric/impl_numeric.rs +++ b/src/numeric/impl_numeric.rs @@ -6,7 +6,7 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -use std::ops::{Add, Div}; +use std::ops::{Add, Div, Mul}; use libnum::{self, One, Zero, Float}; use itertools::free::enumerate; @@ -33,12 +33,12 @@ impl ArrayBase where A: Clone + Add + libnum::Zero, { if let Some(slc) = self.as_slice_memory_order() { - return numeric_util::unrolled_sum(slc); + return numeric_util::unrolled_fold(slc, A::zero, A::add); } let mut sum = A::zero(); for row in self.inner_rows() { if let Some(slc) = row.as_slice() { - sum = sum + numeric_util::unrolled_sum(slc); + sum = sum + numeric_util::unrolled_fold(slc, A::zero, A::add); } else { sum = sum + row.iter().fold(A::zero(), |acc, elt| acc + elt.clone()); } @@ -46,6 +46,32 @@ impl ArrayBase sum } + /// Return the product of all elements in the array. + /// + /// ``` + /// use ndarray::arr2; + /// + /// let a = arr2(&[[1., 2.], + /// [3., 4.]]); + /// assert_eq!(a.scalar_prod(), 24.); + /// ``` + pub fn scalar_prod(&self) -> A + where A: Clone + Mul + libnum::One, + { + if let Some(slc) = self.as_slice_memory_order() { + return numeric_util::unrolled_fold(slc, A::one, A::mul); + } + let mut sum = A::one(); + for row in self.inner_rows() { + if let Some(slc) = row.as_slice() { + sum = sum * numeric_util::unrolled_fold(slc, A::one, A::mul); + } else { + sum = sum * row.iter().fold(A::one(), |acc, elt| acc * elt.clone()); + } + } + sum + } + /// Return sum along `axis`. /// /// ``` diff --git a/src/numeric_util.rs b/src/numeric_util.rs index a938c6287..9fc7e5cf0 100644 --- a/src/numeric_util.rs +++ b/src/numeric_util.rs @@ -5,50 +5,47 @@ // , at your // option. This file may not be copied, modified, or distributed // except according to those terms. -use libnum; - use std::cmp; -use std::ops::{ - Add, -}; use LinalgScalar; -/// Compute the sum of the values in `xs` -pub fn unrolled_sum(mut xs: &[A]) -> A - where A: Clone + Add + libnum::Zero, +/// Fold over the manually unrolled `xs` with `f` +pub fn unrolled_fold(mut xs: &[A], init: I, f: F) -> A + where A: Clone, + I: Fn() -> A, + F: Fn(A, A) -> A, { // eightfold unrolled so that floating point can be vectorized // (even with strict floating point accuracy semantics) - let mut sum = A::zero(); + let mut acc = init(); let (mut p0, mut p1, mut p2, mut p3, mut p4, mut p5, mut p6, mut p7) = - (A::zero(), A::zero(), A::zero(), A::zero(), - A::zero(), A::zero(), A::zero(), A::zero()); + (init(), init(), init(), init(), + init(), init(), init(), init()); while xs.len() >= 8 { - p0 = p0 + xs[0].clone(); - p1 = p1 + xs[1].clone(); - p2 = p2 + xs[2].clone(); - p3 = p3 + xs[3].clone(); - p4 = p4 + xs[4].clone(); - p5 = p5 + xs[5].clone(); - p6 = p6 + xs[6].clone(); - p7 = p7 + xs[7].clone(); + p0 = f(p0, xs[0].clone()); + p1 = f(p1, xs[1].clone()); + p2 = f(p2, xs[2].clone()); + p3 = f(p3, xs[3].clone()); + p4 = f(p4, xs[4].clone()); + p5 = f(p5, xs[5].clone()); + p6 = f(p6, xs[6].clone()); + p7 = f(p7, xs[7].clone()); xs = &xs[8..]; } - sum = sum.clone() + (p0 + p4); - sum = sum.clone() + (p1 + p5); - sum = sum.clone() + (p2 + p6); - sum = sum.clone() + (p3 + p7); + acc = f(acc.clone(), f(p0, p4)); + acc = f(acc.clone(), f(p1, p5)); + acc = f(acc.clone(), f(p2, p6)); + acc = f(acc.clone(), f(p3, p7)); // make it clear to the optimizer that this loop is short // and can not be autovectorized. for i in 0..xs.len() { if i >= 7 { break; } - sum = sum.clone() + xs[i].clone() + acc = f(acc.clone(), xs[i].clone()) } - sum + acc } /// Compute the dot product. diff --git a/tests/oper.rs b/tests/oper.rs index c49a692dd..1197f284d 100644 --- a/tests/oper.rs +++ b/tests/oper.rs @@ -271,6 +271,26 @@ fn fold_and_sum() { } } +#[test] +fn scalar_prod() { + let a = Array::linspace(0.5, 2., 128).into_shape((8, 16)).unwrap(); + assert_approx_eq(a.fold(1., |acc, &x| acc * x), a.scalar_prod(), 1e-5); + + // test different strides + let max = 8 as Ixs; + for i in 1..max { + for j in 1..max { + let a1 = a.slice(s![..;i, ..;j]); + let mut prod = 1.; + for elt in a1.iter() { + prod *= *elt; + } + assert_approx_eq(a1.fold(1., |acc, &x| acc * x), prod, 1e-5); + assert_approx_eq(prod, a1.scalar_prod(), 1e-5); + } + } +} + fn range_mat(m: Ix, n: Ix) -> Array2 { Array::linspace(0., (m * n) as f32 - 1., m * n).into_shape((m, n)).unwrap() }