Skip to content

Commit 8f95705

Browse files
Alternative implementation for sum_axis
1 parent b3d2b42 commit 8f95705

File tree

1 file changed

+5
-16
lines changed

1 file changed

+5
-16
lines changed

src/numeric/impl_numeric.rs

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -103,22 +103,11 @@ impl<A, S, D> ArrayBase<S, D>
103103
where A: Clone + Zero + Add<Output=A>,
104104
D: RemoveAxis,
105105
{
106-
let n = self.len_of(axis);
107-
let stride = self.strides()[axis.index()];
108-
let mut res = Array::zeros(self.raw_dim().remove_axis(axis));
109-
if self.ndim() == 2 && stride == 1 {
110-
// contiguous along the axis we are summing
111-
let ax = axis.index();
112-
for (i, elt) in enumerate(&mut res) {
113-
*elt = self.index_axis(Axis(1 - ax), i).sum();
114-
}
115-
res
116-
} else {
117-
numeric_util::array_pairwise_sum(
118-
(0..n).map(|i| self.index_axis(axis, i)),
119-
|| res.clone()
120-
)
121-
}
106+
let mut out = Array::zeros(self.dim.remove_axis(axis));
107+
Zip::from(&mut out)
108+
.and(self.lanes(axis))
109+
.apply(|out, lane| *out = lane.sum());
110+
out
122111
}
123112

124113
/// Return mean along `axis`.

0 commit comments

Comments
 (0)