Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 93 additions & 1 deletion src/tensor.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use num::{Float, Num};
use num::{Float, Integer, Num};
use std::hash::Hash;
use std::ops::{Add, Div, Mul, Sub};

use crate::axes::Axes;
Expand Down Expand Up @@ -322,6 +323,62 @@ impl<T: Num + PartialOrd + Copy> Tensor<T> {
}
}

impl<T: Integer + PartialOrd + Eq + Hash + Copy> Tensor<T> {
pub fn mode(&self, axes: Axes) -> Tensor<T> {
use std::collections::HashMap;

let all_axes = (0..self.shape.order()).collect::<Vec<_>>();
let remaining_axes = all_axes
.clone()
.into_iter()
.filter(|&i| !axes.contains(&i))
.collect::<Vec<_>>();
let remaining_dims = remaining_axes
.iter()
.map(|&i| self.shape[i])
.collect::<Vec<_>>();
let removing_dims = axes.iter().map(|&i| self.shape[i]).collect::<Vec<_>>();

// We resolve to a scalar value
if axes.is_empty() || remaining_dims.is_empty() {
let mut frequency_map = HashMap::new();
for &value in &self.data {
*frequency_map.entry(value).or_insert(0) += 1;
}
let mut frequency_vec: Vec<(T, usize)> = frequency_map.into_iter().collect();
frequency_vec.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
let mode: T = frequency_vec.into_iter().next().unwrap().0;
return Tensor::new(&Shape::new(vec![1]).unwrap(), &[mode]).unwrap();
}

// Create new tensor with right shape
let new_shape = Shape::new(remaining_dims).unwrap();
let remove_shape = Shape::new(removing_dims).unwrap();
let mut t: Tensor<T> = Tensor::zeros(&new_shape);

for target in IndexIterator::new(&new_shape) {
let mut frequency_map = HashMap::new();
let mode_iter = IndexIterator::new(&remove_shape);
for mode_index in mode_iter {
let mut indices = target.clone();
for (i, &axis) in axes.iter().enumerate() {
indices = indices.insert(axis, mode_index[i]);
}

let value = self.get(&indices).unwrap();
*frequency_map.entry(*value).or_insert(0) += 1;
}

let mut frequency_vec: Vec<(T, usize)> = frequency_map.into_iter().collect();
frequency_vec.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
let mode: T = frequency_vec.into_iter().next().unwrap().0;
let _ = t.set(&target, mode);
}

t
}
}

impl<T: Float + PartialOrd + Copy> Tensor<T> {
pub fn pow(&self, power: T) -> Tensor<T> {
let mut result = Tensor::zeros(&self.shape);
Expand Down Expand Up @@ -985,6 +1042,41 @@ mod tests {
assert_eq!(result.data, DynamicStorage::new(vec![-10.0, -8.0, -12.0]));
}

#[test]
fn test_tensor_mode_no_axis_1d() {
let shape = shape![5].unwrap();
let data = vec![1, 2, 2, 3, 3];
let tensor = Tensor::new(&shape, &data).unwrap();
let result = tensor.mode(vec![]);

assert_eq!(result.shape(), &shape![1].unwrap());
assert_eq!(result.data, DynamicStorage::new(vec![2]));
}

#[test]
fn test_tensor_mode_one_axis_2d() {
let shape = shape![2, 3].unwrap();
let data = vec![1, 2, 2, 3, 3, 3];
let tensor = Tensor::new(&shape, &data).unwrap();

let result = tensor.mode(vec![0]);

assert_eq!(result.shape(), &shape![3].unwrap());
assert_eq!(result.data, DynamicStorage::new(vec![1, 2, 2]));
}

#[test]
fn test_tensor_mode_multiple_axes_3d() {
let shape = shape![2, 2, 3].unwrap();
let data = vec![1, 2, 2, 3, 3, 3, 1, 2, 2, 3, 3, 3];
let tensor = Tensor::new(&shape, &data).unwrap();

let result = tensor.mode(vec![0, 1]);

assert_eq!(result.shape(), &shape![3].unwrap());
assert_eq!(result.data, DynamicStorage::new(vec![1, 2, 2]));
}

#[test]
fn test_tensor_prod_1d_1d() {
let shape1 = shape![3].unwrap();
Expand Down