diff --git a/src/tensor.rs b/src/tensor.rs index 595dc26..c0df899 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -1,4 +1,5 @@ -use num::{Float, Num}; +use num::{Float, Integer, Num}; +use std::hash::Hash; use std::ops::{Add, Div, Mul, Sub}; use crate::axes::Axes; @@ -322,6 +323,62 @@ impl Tensor { } } +impl Tensor { + pub fn mode(&self, axes: Axes) -> Tensor { + use std::collections::HashMap; + + let all_axes = (0..self.shape.order()).collect::>(); + let remaining_axes = all_axes + .clone() + .into_iter() + .filter(|&i| !axes.contains(&i)) + .collect::>(); + let remaining_dims = remaining_axes + .iter() + .map(|&i| self.shape[i]) + .collect::>(); + let removing_dims = axes.iter().map(|&i| self.shape[i]).collect::>(); + + // We resolve to a scalar value + if axes.is_empty() || remaining_dims.is_empty() { + let mut frequency_map = HashMap::new(); + for &value in &self.data { + *frequency_map.entry(value).or_insert(0) += 1; + } + let mut frequency_vec: Vec<(T, usize)> = frequency_map.into_iter().collect(); + frequency_vec.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0))); + let mode: T = frequency_vec.into_iter().next().unwrap().0; + return Tensor::new(&Shape::new(vec![1]).unwrap(), &[mode]).unwrap(); + } + + // Create new tensor with right shape + let new_shape = Shape::new(remaining_dims).unwrap(); + let remove_shape = Shape::new(removing_dims).unwrap(); + let mut t: Tensor = Tensor::zeros(&new_shape); + + for target in IndexIterator::new(&new_shape) { + let mut frequency_map = HashMap::new(); + let mode_iter = IndexIterator::new(&remove_shape); + for mode_index in mode_iter { + let mut indices = target.clone(); + for (i, &axis) in axes.iter().enumerate() { + indices = indices.insert(axis, mode_index[i]); + } + + let value = self.get(&indices).unwrap(); + *frequency_map.entry(*value).or_insert(0) += 1; + } + + let mut frequency_vec: Vec<(T, usize)> = frequency_map.into_iter().collect(); + frequency_vec.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0))); + let mode: T = frequency_vec.into_iter().next().unwrap().0; + let _ = t.set(&target, mode); + } + + t + } +} + impl Tensor { pub fn pow(&self, power: T) -> Tensor { let mut result = Tensor::zeros(&self.shape); @@ -985,6 +1042,41 @@ mod tests { assert_eq!(result.data, DynamicStorage::new(vec![-10.0, -8.0, -12.0])); } + #[test] + fn test_tensor_mode_no_axis_1d() { + let shape = shape![5].unwrap(); + let data = vec![1, 2, 2, 3, 3]; + let tensor = Tensor::new(&shape, &data).unwrap(); + let result = tensor.mode(vec![]); + + assert_eq!(result.shape(), &shape![1].unwrap()); + assert_eq!(result.data, DynamicStorage::new(vec![2])); + } + + #[test] + fn test_tensor_mode_one_axis_2d() { + let shape = shape![2, 3].unwrap(); + let data = vec![1, 2, 2, 3, 3, 3]; + let tensor = Tensor::new(&shape, &data).unwrap(); + + let result = tensor.mode(vec![0]); + + assert_eq!(result.shape(), &shape![3].unwrap()); + assert_eq!(result.data, DynamicStorage::new(vec![1, 2, 2])); + } + + #[test] + fn test_tensor_mode_multiple_axes_3d() { + let shape = shape![2, 2, 3].unwrap(); + let data = vec![1, 2, 2, 3, 3, 3, 1, 2, 2, 3, 3, 3]; + let tensor = Tensor::new(&shape, &data).unwrap(); + + let result = tensor.mode(vec![0, 1]); + + assert_eq!(result.shape(), &shape![3].unwrap()); + assert_eq!(result.data, DynamicStorage::new(vec![1, 2, 2])); + } + #[test] fn test_tensor_prod_1d_1d() { let shape1 = shape![3].unwrap();