diff --git a/src/matrix.rs b/src/matrix.rs index 215be90..47dd097 100644 --- a/src/matrix.rs +++ b/src/matrix.rs @@ -10,6 +10,7 @@ use crate::tensor::DynamicTensor; use crate::vector::DynamicVector; use num::{Float, Num}; +#[derive(Debug, PartialEq)] pub struct DynamicMatrix { tensor: DynamicTensor, } @@ -263,11 +264,7 @@ mod tests { let data = vec![1.0, 2.0, 3.0, 4.0]; let tensor = DynamicTensor::new(&shape, &data).unwrap(); let matrix = DynamicMatrix::from_tensor(tensor).unwrap(); - assert_eq!(matrix.shape(), &shape); - assert_eq!(matrix[coord![0, 0].unwrap()], 1.0); - assert_eq!(matrix[coord![0, 1].unwrap()], 2.0); - assert_eq!(matrix[coord![1, 0].unwrap()], 3.0); - assert_eq!(matrix[coord![1, 1].unwrap()], 4.0); + assert_eq!(matrix, DynamicMatrix::new(&shape, &data).unwrap()); } #[test] @@ -283,49 +280,31 @@ mod tests { fn test_fill() { let shape = shape![2, 2].unwrap(); let matrix = DynamicMatrix::fill(&shape, 3.0).unwrap(); - assert_eq!(matrix.shape(), &shape); - assert_eq!(matrix[coord![0, 0].unwrap()], 3.0); - assert_eq!(matrix[coord![0, 1].unwrap()], 3.0); - assert_eq!(matrix[coord![1, 0].unwrap()], 3.0); - assert_eq!(matrix[coord![1, 1].unwrap()], 3.0); + assert_eq!(matrix, DynamicMatrix::new(&shape, &[3.0; 4]).unwrap()); } #[test] fn test_eye() { let shape = shape![3, 3].unwrap(); let matrix = DynamicMatrix::::eye(&shape).unwrap(); - assert_eq!(matrix.shape(), &shape); - assert_eq!(matrix[coord![0, 0].unwrap()], 1.0); - assert_eq!(matrix[coord![0, 1].unwrap()], 0.0); - assert_eq!(matrix[coord![0, 2].unwrap()], 0.0); - assert_eq!(matrix[coord![1, 0].unwrap()], 0.0); - assert_eq!(matrix[coord![1, 1].unwrap()], 1.0); - assert_eq!(matrix[coord![1, 2].unwrap()], 0.0); - assert_eq!(matrix[coord![2, 0].unwrap()], 0.0); - assert_eq!(matrix[coord![2, 1].unwrap()], 0.0); - assert_eq!(matrix[coord![2, 2].unwrap()], 1.0); + assert_eq!( + matrix, + DynamicMatrix::new(&shape, &[1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]).unwrap() + ); } #[test] fn test_zeros() { let shape = shape![2, 2].unwrap(); let matrix = DynamicMatrix::::zeros(&shape).unwrap(); - assert_eq!(matrix.shape(), &shape); - assert_eq!(matrix[coord![0, 0].unwrap()], 0.0); - assert_eq!(matrix[coord![0, 1].unwrap()], 0.0); - assert_eq!(matrix[coord![1, 0].unwrap()], 0.0); - assert_eq!(matrix[coord![1, 1].unwrap()], 0.0); + assert_eq!(matrix, DynamicMatrix::new(&shape, &[0.0; 4]).unwrap()); } #[test] fn test_ones() { let shape = shape![2, 2].unwrap(); let matrix = DynamicMatrix::::ones(&shape).unwrap(); - assert_eq!(matrix.shape(), &shape); - assert_eq!(matrix[coord![0, 0].unwrap()], 1.0); - assert_eq!(matrix[coord![0, 1].unwrap()], 1.0); - assert_eq!(matrix[coord![1, 0].unwrap()], 1.0); - assert_eq!(matrix[coord![1, 1].unwrap()], 1.0); + assert_eq!(matrix, DynamicMatrix::new(&shape, &[1.0; 4]).unwrap()); } #[test] @@ -349,11 +328,10 @@ mod tests { let data = vec![1.0, 2.0, 3.0, 4.0]; let mut matrix = DynamicMatrix::new(&shape, &data).unwrap(); matrix[coord![1, 0].unwrap()] = 5.0; - assert_eq!(matrix.shape(), &shape); - assert_eq!(matrix[coord![0, 0].unwrap()], 1.0); - assert_eq!(matrix[coord![0, 1].unwrap()], 2.0); - assert_eq!(matrix[coord![1, 0].unwrap()], 5.0); - assert_eq!(matrix[coord![1, 1].unwrap()], 4.0); + assert_eq!( + matrix, + DynamicMatrix::new(&shape, &[1.0, 2.0, 5.0, 4.0]).unwrap() + ); } #[test] @@ -362,11 +340,10 @@ mod tests { let data = vec![1.0, 2.0, 3.0, 4.0]; let mut matrix = DynamicMatrix::new(&shape, &data).unwrap(); matrix.set(&coord![1, 0].unwrap(), 5.0).unwrap(); - assert_eq!(matrix.shape(), &shape); - assert_eq!(matrix[coord![0, 0].unwrap()], 1.0); - assert_eq!(matrix[coord![0, 1].unwrap()], 2.0); - assert_eq!(matrix[coord![1, 0].unwrap()], 5.0); - assert_eq!(matrix[coord![1, 1].unwrap()], 4.0); + assert_eq!( + matrix, + DynamicMatrix::new(&shape, &[1.0, 2.0, 5.0, 4.0]).unwrap() + ); } #[test] @@ -375,8 +352,7 @@ mod tests { let data = vec![1.0, 2.0, 3.0, 4.0]; let matrix = DynamicMatrix::new(&shape, &data).unwrap(); let result = matrix.sum(vec![0, 1]); - assert_eq!(result[0], 10.0); - assert_eq!(result.shape(), &shape![1].unwrap()); + assert_eq!(result, DynamicVector::new(&[10.0]).unwrap()); } #[test] @@ -385,8 +361,7 @@ mod tests { let data = vec![1.0, 2.0, 3.0, 4.0]; let matrix = DynamicMatrix::new(&shape, &data).unwrap(); let result = matrix.mean(vec![0, 1]); - assert_eq!(result[0], 2.5); - assert_eq!(result.shape(), &shape![1].unwrap()); + assert_eq!(result, DynamicVector::new(&[2.5]).unwrap()); } #[test] @@ -395,8 +370,7 @@ mod tests { let data = vec![1.0, 2.0, 3.0, 4.0]; let matrix = DynamicMatrix::new(&shape, &data).unwrap(); let result = matrix.var(vec![0, 1]); - assert_eq!(result[0], 1.25); - assert_eq!(result.shape(), &shape![1].unwrap()); + assert_eq!(result, DynamicVector::new(&[1.25]).unwrap()); } #[test] @@ -405,8 +379,7 @@ mod tests { let data = vec![1.0, 2.0, 3.0, 4.0]; let matrix = DynamicMatrix::new(&shape, &data).unwrap(); let result = matrix.min(vec![0, 1]); - assert_eq!(result[0], 1.0); - assert_eq!(result.shape(), &shape![1].unwrap()); + assert_eq!(result, DynamicVector::new(&[1.0]).unwrap()); } #[test] @@ -415,8 +388,7 @@ mod tests { let data = vec![-1.0, -2.0, -3.0, -4.0]; let matrix = DynamicMatrix::new(&shape, &data).unwrap(); let result = matrix.max(vec![0, 1]); - assert_eq!(result[0], -1.0); - assert_eq!(result.shape(), &shape![1].unwrap()); + assert_eq!(result, DynamicVector::new(&[-1.0]).unwrap()); } #[test] @@ -427,11 +399,10 @@ mod tests { let matrix1 = DynamicMatrix::new(&shape, &data1).unwrap(); let matrix2 = DynamicMatrix::new(&shape, &data2).unwrap(); let result = matrix1.matmul(&matrix2); - assert_eq!(result.shape(), &shape); - assert_eq!(result[coord![0, 0].unwrap()], 10.0); - assert_eq!(result[coord![0, 1].unwrap()], 13.0); - assert_eq!(result[coord![1, 0].unwrap()], 22.0); - assert_eq!(result[coord![1, 1].unwrap()], 29.0); + assert_eq!( + result, + DynamicMatrix::new(&shape, &[10.0, 13.0, 22.0, 29.0]).unwrap() + ); } #[test] @@ -442,9 +413,7 @@ mod tests { let vector_data = vec![1.0, 2.0]; let vector = DynamicVector::new(&vector_data).unwrap(); let result = matrix.vecmul(&vector); - assert_eq!(result.shape(), &shape![2].unwrap()); - assert_eq!(result[0], 5.0); - assert_eq!(result[1], 11.0); + assert_eq!(result, DynamicVector::new(&[5.0, 11.0]).unwrap()); } #[test] @@ -453,11 +422,10 @@ mod tests { let data = vec![1.0, 2.0, 3.0, 4.0]; let matrix = DynamicMatrix::new(&shape, &data).unwrap(); let result = matrix + 2.0; - assert_eq!(result[coord![0, 0].unwrap()], 3.0); - assert_eq!(result[coord![0, 1].unwrap()], 4.0); - assert_eq!(result[coord![1, 0].unwrap()], 5.0); - assert_eq!(result[coord![1, 1].unwrap()], 6.0); - assert_eq!(result.shape(), &shape); + assert_eq!( + result, + DynamicMatrix::new(&shape, &[3.0, 4.0, 5.0, 6.0]).unwrap() + ); } #[test] @@ -468,11 +436,10 @@ mod tests { let matrix1 = DynamicMatrix::new(&shape, &data1).unwrap(); let matrix2 = DynamicMatrix::new(&shape, &data2).unwrap(); let result = matrix1 + matrix2; - assert_eq!(result[coord![0, 0].unwrap()], 3.0); - assert_eq!(result[coord![0, 1].unwrap()], 5.0); - assert_eq!(result[coord![1, 0].unwrap()], 7.0); - assert_eq!(result[coord![1, 1].unwrap()], 9.0); - assert_eq!(result.shape(), &shape); + assert_eq!( + result, + DynamicMatrix::new(&shape, &[3.0, 5.0, 7.0, 9.0]).unwrap() + ); } #[test] @@ -483,11 +450,10 @@ mod tests { let matrix = DynamicMatrix::new(&shape, &data1).unwrap(); let tensor = DynamicTensor::new(&shape, &data2).unwrap(); let result = matrix + tensor; - assert_eq!(result[coord![0, 0].unwrap()], 3.0); - assert_eq!(result[coord![0, 1].unwrap()], 5.0); - assert_eq!(result[coord![1, 0].unwrap()], 7.0); - assert_eq!(result[coord![1, 1].unwrap()], 9.0); - assert_eq!(result.shape(), &shape); + assert_eq!( + result, + DynamicMatrix::new(&shape, &[3.0, 5.0, 7.0, 9.0]).unwrap() + ); } #[test] @@ -496,11 +462,10 @@ mod tests { let data = vec![1.0, 2.0, 3.0, 4.0]; let matrix = DynamicMatrix::new(&shape, &data).unwrap(); let result = matrix - 2.0; - assert_eq!(result[coord![0, 0].unwrap()], -1.0); - assert_eq!(result[coord![0, 1].unwrap()], 0.0); - assert_eq!(result[coord![1, 0].unwrap()], 1.0); - assert_eq!(result[coord![1, 1].unwrap()], 2.0); - assert_eq!(result.shape(), &shape); + assert_eq!( + result, + DynamicMatrix::new(&shape, &[-1.0, 0.0, 1.0, 2.0]).unwrap() + ); } #[test] @@ -511,11 +476,7 @@ mod tests { let matrix1 = DynamicMatrix::new(&shape, &data1).unwrap(); let matrix2 = DynamicMatrix::new(&shape, &data2).unwrap(); let result = matrix1 - matrix2; - assert_eq!(result[coord![0, 0].unwrap()], -1.0); - assert_eq!(result[coord![0, 1].unwrap()], -1.0); - assert_eq!(result[coord![1, 0].unwrap()], -1.0); - assert_eq!(result[coord![1, 1].unwrap()], -1.0); - assert_eq!(result.shape(), &shape); + assert_eq!(result, DynamicMatrix::new(&shape, &[-1.0; 4]).unwrap()); } #[test] @@ -526,11 +487,7 @@ mod tests { let matrix = DynamicMatrix::new(&shape, &data1).unwrap(); let tensor = DynamicTensor::new(&shape, &data2).unwrap(); let result = matrix - tensor; - assert_eq!(result[coord![0, 0].unwrap()], -1.0); - assert_eq!(result[coord![0, 1].unwrap()], -1.0); - assert_eq!(result[coord![1, 0].unwrap()], -1.0); - assert_eq!(result[coord![1, 1].unwrap()], -1.0); - assert_eq!(result.shape(), &shape); + assert_eq!(result, DynamicMatrix::new(&shape, &[-1.0; 4]).unwrap()); } #[test] @@ -539,11 +496,10 @@ mod tests { let data = vec![1.0, 2.0, 3.0, 4.0]; let matrix = DynamicMatrix::new(&shape, &data).unwrap(); let result = matrix * 2.0; - assert_eq!(result[coord![0, 0].unwrap()], 2.0); - assert_eq!(result[coord![0, 1].unwrap()], 4.0); - assert_eq!(result[coord![1, 0].unwrap()], 6.0); - assert_eq!(result[coord![1, 1].unwrap()], 8.0); - assert_eq!(result.shape(), &shape); + assert_eq!( + result, + DynamicMatrix::new(&shape, &[2.0, 4.0, 6.0, 8.0]).unwrap() + ); } #[test] @@ -554,11 +510,10 @@ mod tests { let matrix1 = DynamicMatrix::new(&shape, &data1).unwrap(); let matrix2 = DynamicMatrix::new(&shape, &data2).unwrap(); let result = matrix1 * matrix2; - assert_eq!(result[coord![0, 0].unwrap()], 2.0); - assert_eq!(result[coord![0, 1].unwrap()], 6.0); - assert_eq!(result[coord![1, 0].unwrap()], 12.0); - assert_eq!(result[coord![1, 1].unwrap()], 20.0); - assert_eq!(result.shape(), &shape); + assert_eq!( + result, + DynamicMatrix::new(&shape, &[2.0, 6.0, 12.0, 20.0]).unwrap() + ); } #[test] @@ -569,11 +524,10 @@ mod tests { let matrix = DynamicMatrix::new(&shape, &data1).unwrap(); let tensor = DynamicTensor::new(&shape, &data2).unwrap(); let result = matrix * tensor; - assert_eq!(result[coord![0, 0].unwrap()], 2.0); - assert_eq!(result[coord![0, 1].unwrap()], 6.0); - assert_eq!(result[coord![1, 0].unwrap()], 12.0); - assert_eq!(result[coord![1, 1].unwrap()], 20.0); - assert_eq!(result.shape(), &shape); + assert_eq!( + result, + DynamicMatrix::new(&shape, &[2.0, 6.0, 12.0, 20.0]).unwrap() + ); } #[test] @@ -582,11 +536,10 @@ mod tests { let data = vec![4.0, 6.0, 8.0, 10.0]; let matrix = DynamicMatrix::new(&shape, &data).unwrap(); let result = matrix / 2.0; - assert_eq!(result[coord![0, 0].unwrap()], 2.0); - assert_eq!(result[coord![0, 1].unwrap()], 3.0); - assert_eq!(result[coord![1, 0].unwrap()], 4.0); - assert_eq!(result[coord![1, 1].unwrap()], 5.0); - assert_eq!(result.shape(), &shape); + assert_eq!( + result, + DynamicMatrix::new(&shape, &[2.0, 3.0, 4.0, 5.0]).unwrap() + ); } #[test] @@ -597,11 +550,7 @@ mod tests { let matrix1 = DynamicMatrix::new(&shape, &data1).unwrap(); let matrix2 = DynamicMatrix::new(&shape, &data2).unwrap(); let result = matrix1 / matrix2; - assert_eq!(result[coord![0, 0].unwrap()], 2.0); - assert_eq!(result[coord![0, 1].unwrap()], 2.0); - assert_eq!(result[coord![1, 0].unwrap()], 2.0); - assert_eq!(result[coord![1, 1].unwrap()], 2.0); - assert_eq!(result.shape(), &shape); + assert_eq!(result, DynamicMatrix::new(&shape, &[2.0; 4]).unwrap()); } #[test] @@ -612,11 +561,7 @@ mod tests { let matrix = DynamicMatrix::new(&shape, &data1).unwrap(); let tensor = DynamicTensor::new(&shape, &data2).unwrap(); let result = matrix / tensor; - assert_eq!(result[coord![0, 0].unwrap()], 2.0); - assert_eq!(result[coord![0, 1].unwrap()], 2.0); - assert_eq!(result[coord![1, 0].unwrap()], 2.0); - assert_eq!(result[coord![1, 1].unwrap()], 2.0); - assert_eq!(result.shape(), &shape); + assert_eq!(result, DynamicMatrix::new(&shape, &[2.0; 4]).unwrap()); } #[test] diff --git a/src/tensor.rs b/src/tensor.rs index 595dc26..5944e9d 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -11,7 +11,7 @@ use crate::shape::Shape; use crate::storage::DynamicStorage; use crate::vector::DynamicVector; -#[derive(Debug)] +#[derive(Debug, PartialEq)] pub struct DynamicTensor { data: DynamicStorage, shape: Shape, @@ -572,7 +572,6 @@ mod tests { let data = vec![1.0, 2.0, 3.0]; // Mismatched data length let result = Tensor::new(&shape, &data); - assert!(result.is_err()); } @@ -580,34 +579,27 @@ mod tests { fn test_zeros_tensor() { let shape = shape![2, 3].unwrap(); let tensor: Tensor = Tensor::zeros(&shape); - - assert_eq!(tensor.shape(), &shape); - assert_eq!(tensor.data, DynamicStorage::new(vec![0.0; shape.size()])); + assert_eq!(tensor, DynamicTensor::new(&shape, &[0.0; 6]).unwrap()); } #[test] fn test_ones_tensor() { let shape = shape![2, 3].unwrap(); let tensor: Tensor = Tensor::ones(&shape); - - assert_eq!(tensor.shape(), &shape); - assert_eq!(tensor.data, DynamicStorage::new(vec![1.0; shape.size()])); + assert_eq!(tensor, DynamicTensor::new(&shape, &[1.0; 6]).unwrap()); } #[test] fn test_fill_tensor() { let shape = shape![2, 3].unwrap(); let tensor: Tensor = Tensor::fill(&shape, 7.0); - - assert_eq!(tensor.shape(), &shape); - assert_eq!(tensor.data, DynamicStorage::new(vec![7.0; shape.size()])); + assert_eq!(tensor, DynamicTensor::new(&shape, &[7.0; 6]).unwrap()); } #[test] fn test_tensor_shape() { let shape = shape![2, 3].unwrap(); let tensor: Tensor = Tensor::zeros(&shape); - assert_eq!(tensor.shape(), &shape); } @@ -615,7 +607,6 @@ mod tests { fn test_tensor_size() { let shape = shape![2, 3].unwrap(); let tensor: Tensor = Tensor::zeros(&shape); - assert_eq!(tensor.size(), 6); } @@ -641,11 +632,10 @@ mod tests { tensor.set(&coord![0, 1].unwrap(), 6.0).unwrap(); tensor.set(&coord![1, 0].unwrap(), 7.0).unwrap(); tensor.set(&coord![1, 1].unwrap(), 8.0).unwrap(); - - assert_eq!(*tensor.get(&coord![0, 0].unwrap()).unwrap(), 5.0); - assert_eq!(*tensor.get(&coord![0, 1].unwrap()).unwrap(), 6.0); - assert_eq!(*tensor.get(&coord![1, 0].unwrap()).unwrap(), 7.0); - assert_eq!(*tensor.get(&coord![1, 1].unwrap()).unwrap(), 8.0); + assert_eq!( + tensor, + DynamicTensor::new(&shape, &[5.0, 6.0, 7.0, 8.0]).unwrap() + ); } #[test] @@ -677,9 +667,10 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.sum(vec![]); - - assert_eq!(result.shape(), &shape![1].unwrap()); - assert_eq!(result.data, DynamicStorage::new(vec![15.0])); + assert_eq!( + result, + DynamicTensor::new(&shape![1].unwrap(), &[15.0]).unwrap() + ); } #[test] @@ -689,9 +680,10 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.sum(vec![]); - - assert_eq!(result.shape(), &shape![1].unwrap()); - assert_eq!(result.data, DynamicStorage::new(vec![21.0])); + assert_eq!( + result, + DynamicTensor::new(&shape![1].unwrap(), &[21.0]).unwrap() + ); } #[test] @@ -703,9 +695,10 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.sum(vec![]); - - assert_eq!(result.shape(), &shape![1].unwrap()); - assert_eq!(result.data, DynamicStorage::new(vec![78.0])); + assert_eq!( + result, + DynamicTensor::new(&shape![1].unwrap(), &[78.0]).unwrap() + ); } #[test] @@ -715,9 +708,10 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.sum(vec![0]); - - assert_eq!(result.shape(), &shape![1].unwrap()); - assert_eq!(result.data, DynamicStorage::new(vec![15.0])); + assert_eq!( + result, + DynamicTensor::new(&shape![1].unwrap(), &[15.0]).unwrap() + ); } #[test] @@ -727,9 +721,10 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.sum(vec![0]); - - assert_eq!(result.shape(), &shape![3].unwrap()); - assert_eq!(result.data, DynamicStorage::new(vec![5.0, 7.0, 9.0])); + assert_eq!( + result, + DynamicTensor::new(&shape![3].unwrap(), &[5.0, 7.0, 9.0]).unwrap() + ); } #[test] @@ -741,11 +736,10 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.sum(vec![0]); - - assert_eq!(result.shape(), &shape![2, 3].unwrap()); assert_eq!( - result.data, - DynamicStorage::new(vec![8.0, 10.0, 12.0, 14.0, 16.0, 18.0]) + result, + DynamicTensor::new(&shape![2, 3].unwrap(), &[8.0, 10.0, 12.0, 14.0, 16.0, 18.0]) + .unwrap() ); } @@ -756,9 +750,10 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.sum(vec![0, 1]); - - assert_eq!(result.shape(), &shape![1].unwrap()); - assert_eq!(result.data, DynamicStorage::new(vec![21.0])); + assert_eq!( + result, + DynamicTensor::new(&shape![1].unwrap(), &[21.0]).unwrap() + ); } #[test] @@ -770,9 +765,10 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.sum(vec![0, 1]); - - assert_eq!(result.shape(), &shape![3].unwrap()); - assert_eq!(result.data, DynamicStorage::new(vec![22.0, 26.0, 30.0])); + assert_eq!( + result, + DynamicTensor::new(&shape![3].unwrap(), &[22.0, 26.0, 30.0]).unwrap() + ); } #[test] @@ -782,9 +778,10 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.mean(vec![]); - - assert_eq!(result.shape(), &shape![1].unwrap()); - assert_eq!(result.data, DynamicStorage::new(vec![3.5])); + assert_eq!( + result, + DynamicTensor::new(&shape![1].unwrap(), &[3.5]).unwrap() + ); } #[test] @@ -794,9 +791,10 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.mean(vec![0]); - - assert_eq!(result.shape(), &shape![3].unwrap()); - assert_eq!(result.data, DynamicStorage::new(vec![2.5, 3.5, 4.5])); + assert_eq!( + result, + DynamicTensor::new(&shape![3].unwrap(), &[2.5, 3.5, 4.5]).unwrap() + ); } #[test] @@ -808,11 +806,9 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.mean(vec![0]); - - assert_eq!(result.shape(), &shape![2, 3].unwrap()); assert_eq!( - result.data, - DynamicStorage::new(vec![4.0, 5.0, 6.0, 7.0, 8.0, 9.0]) + result, + DynamicTensor::new(&shape![2, 3].unwrap(), &[4.0, 5.0, 6.0, 7.0, 8.0, 9.0]).unwrap() ); } @@ -823,9 +819,10 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.mean(vec![0, 1]); - - assert_eq!(result.shape(), &shape![1].unwrap()); - assert_eq!(result.data, DynamicStorage::new(vec![3.5])); + assert_eq!( + result, + DynamicTensor::new(&shape![1].unwrap(), &[3.5]).unwrap() + ); } #[test] @@ -837,9 +834,10 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.mean(vec![0, 1]); - - assert_eq!(result.shape(), &shape![3].unwrap()); - assert_eq!(result.data, DynamicStorage::new(vec![5.5, 6.5, 7.5])); + assert_eq!( + result, + DynamicTensor::new(&shape![3].unwrap(), &[5.5, 6.5, 7.5]).unwrap() + ); } #[test] @@ -849,9 +847,10 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.var(vec![]); - - assert_eq!(result.shape(), &shape![1].unwrap()); - assert_eq!(result.data, DynamicStorage::new(vec![9.0])); + assert_eq!( + result, + DynamicTensor::new(&shape![1].unwrap(), &[9.0]).unwrap() + ); } #[test] @@ -861,9 +860,10 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.var(vec![0]); - - assert_eq!(result.shape(), &shape![3].unwrap()); - assert_eq!(result.data, DynamicStorage::new(vec![2.25, 2.25, 2.25])); + assert_eq!( + result, + DynamicTensor::new(&shape![3].unwrap(), &[2.25; 3]).unwrap() + ); } #[test] @@ -875,11 +875,9 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.var(vec![0]); - - assert_eq!(result.shape(), &shape![2, 3].unwrap()); assert_eq!( - result.data, - DynamicStorage::new(vec![9.0, 9.0, 9.0, 9.0, 9.0, 9.0]) + result, + DynamicTensor::new(&shape![2, 3].unwrap(), &[9.0; 6]).unwrap() ); } @@ -890,9 +888,10 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.var(vec![0, 1]); - - assert_eq!(result.shape(), &shape![1].unwrap()); - assert_eq!(result.data, DynamicStorage::new(vec![9.0])); + assert_eq!( + result, + DynamicTensor::new(&shape![1].unwrap(), &[9.0]).unwrap() + ); } #[test] @@ -904,9 +903,10 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.var(vec![0, 1]); - - assert_eq!(result.shape(), &shape![3].unwrap()); - assert_eq!(result.data, DynamicStorage::new(vec![45.0, 45.0, 45.0])); + assert_eq!( + result, + DynamicTensor::new(&shape![3].unwrap(), &[45.0; 3]).unwrap() + ); } #[test] @@ -916,9 +916,10 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.max(vec![]); - - assert_eq!(result.shape(), &shape![1].unwrap()); - assert_eq!(result.data, DynamicStorage::new(vec![5.0])); + assert_eq!( + result, + DynamicTensor::new(&shape![1].unwrap(), &[5.0]).unwrap() + ); } #[test] @@ -928,9 +929,10 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.max(vec![0]); - - assert_eq!(result.shape(), &shape![3].unwrap()); - assert_eq!(result.data, DynamicStorage::new(vec![1.0, 5.0, 3.0])); + assert_eq!( + result, + DynamicTensor::new(&shape![3].unwrap(), &[1.0, 5.0, 3.0]).unwrap() + ); } #[test] @@ -942,9 +944,10 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.max(vec![0, 1]); - - assert_eq!(result.shape(), &shape![3].unwrap()); - assert_eq!(result.data, DynamicStorage::new(vec![7.0, 11.0, 9.0])); + assert_eq!( + result, + DynamicTensor::new(&shape![3].unwrap(), &[7.0, 11.0, 9.0]).unwrap() + ); } #[test] @@ -954,9 +957,10 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.min(vec![]); - - assert_eq!(result.shape(), &shape![1].unwrap()); - assert_eq!(result.data, DynamicStorage::new(vec![-4.0])); + assert_eq!( + result, + DynamicTensor::new(&shape![1].unwrap(), &[-4.0]).unwrap() + ); } #[test] @@ -966,9 +970,10 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.min(vec![0]); - - assert_eq!(result.shape(), &shape![3].unwrap()); - assert_eq!(result.data, DynamicStorage::new(vec![-4.0, -2.0, -6.0])); + assert_eq!( + result, + DynamicTensor::new(&shape![3].unwrap(), &[-4.0, -2.0, -6.0]).unwrap() + ); } #[test] @@ -980,9 +985,10 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.min(vec![0, 1]); - - assert_eq!(result.shape(), &shape![3].unwrap()); - assert_eq!(result.data, DynamicStorage::new(vec![-10.0, -8.0, -12.0])); + assert_eq!( + result, + DynamicTensor::new(&shape![3].unwrap(), &[-10.0, -8.0, -12.0]).unwrap() + ); } #[test] @@ -996,11 +1002,9 @@ mod tests { let tensor2 = Tensor::new(&shape2, &data2).unwrap(); let result = tensor1.prod(&tensor2); - - assert_eq!(result.shape(), &shape![3, 2].unwrap()); assert_eq!( - result.data, - DynamicStorage::new(vec![4.0, 5.0, 8.0, 10.0, 12.0, 15.0]) + result, + DynamicTensor::new(&shape![3, 2].unwrap(), &[4.0, 5.0, 8.0, 10.0, 12.0, 15.0]).unwrap() ); } @@ -1015,33 +1019,35 @@ mod tests { let tensor2 = Tensor::new(&shape2, &data2).unwrap(); let result = tensor1.prod(&tensor2); - - assert_eq!(result.shape(), &shape![2, 2, 2].unwrap()); assert_eq!( - result.data, - DynamicStorage::new(vec![5.0, 6.0, 10.0, 12.0, 15.0, 18.0, 20.0, 24.0]) + result, + DynamicTensor::new( + &shape![2, 2, 2].unwrap(), + &[5.0, 6.0, 10.0, 12.0, 15.0, 18.0, 20.0, 24.0] + ) + .unwrap() ); } #[test] fn test_tensor_prod_2d_2d() { - let shape1 = shape![2, 2].unwrap(); + let shape = shape![2, 2].unwrap(); let data1 = vec![1.0, 2.0, 3.0, 4.0]; - let tensor1 = Tensor::new(&shape1, &data1).unwrap(); - - let shape2 = shape![2, 2].unwrap(); + let tensor1 = Tensor::new(&shape, &data1).unwrap(); let data2 = vec![5.0, 6.0, 7.0, 8.0]; - let tensor2 = Tensor::new(&shape2, &data2).unwrap(); + let tensor2 = Tensor::new(&shape, &data2).unwrap(); let result = tensor1.prod(&tensor2); - - assert_eq!(result.shape(), &shape![2, 2, 2, 2].unwrap()); assert_eq!( - result.data, - DynamicStorage::new(vec![ - 5.0, 6.0, 7.0, 8.0, 10.0, 12.0, 14.0, 16.0, 15.0, 18.0, 21.0, 24.0, 20.0, 24.0, - 28.0, 32.0 - ]) + result, + DynamicTensor::new( + &shape![2, 2, 2, 2].unwrap(), + &[ + 5.0, 6.0, 7.0, 8.0, 10.0, 12.0, 14.0, 16.0, 15.0, 18.0, 21.0, 24.0, 20.0, 24.0, + 28.0, 32.0 + ] + ) + .unwrap() ); } @@ -1052,9 +1058,10 @@ mod tests { let tensor1 = Tensor::new(&shape, &data1).unwrap(); let result = tensor1 + 3.0; - - assert_eq!(result.shape(), &shape); - assert_eq!(result.data, DynamicStorage::new(vec![4.0, 5.0, 6.0, 7.0])); + assert_eq!( + result, + DynamicTensor::new(&shape, &[4.0, 5.0, 6.0, 7.0]).unwrap() + ); } #[test] @@ -1066,22 +1073,23 @@ mod tests { let tensor2 = Tensor::new(&shape, &data2).unwrap(); let result = tensor1 + tensor2; - - assert_eq!(result.shape(), &shape); - assert_eq!(result.data, DynamicStorage::new(vec![6.0, 8.0, 10.0, 12.0])); + assert_eq!( + result, + DynamicTensor::new(&shape, &[6.0, 8.0, 10.0, 12.0]).unwrap() + ); } #[test] fn test_sub_tensor() { let shape = shape![4].unwrap(); let data1 = vec![5.0, 6.0, 7.0, 8.0]; - let tensor1 = Tensor::new(&shape, &data1).unwrap(); let result = tensor1 - 3.0; - - assert_eq!(result.shape(), &shape); - assert_eq!(result.data, DynamicStorage::new(vec![2.0, 3.0, 4.0, 5.0])); + assert_eq!( + result, + DynamicTensor::new(&shape, &[2.0, 3.0, 4.0, 5.0]).unwrap() + ); } #[test] @@ -1093,50 +1101,43 @@ mod tests { let tensor2 = Tensor::new(&shape, &data2).unwrap(); let result = tensor1 - tensor2; - - assert_eq!(result.shape(), &shape); - assert_eq!(result.data, DynamicStorage::new(vec![4.0, 4.0, 4.0, 4.0])); + assert_eq!(result, DynamicTensor::new(&shape, &[4.0; 4]).unwrap()); } #[test] fn test_mul_tensor() { let shape = shape![4].unwrap(); let data1 = vec![1.0, 2.0, 3.0, 4.0]; - let tensor1 = Tensor::new(&shape, &data1).unwrap(); let result = tensor1 * 2.0; - - assert_eq!(result.shape(), &shape); - assert_eq!(result.data, DynamicStorage::new(vec![2.0, 4.0, 6.0, 8.0])); + assert_eq!( + result, + DynamicTensor::new(&shape, &[2.0, 4.0, 6.0, 8.0]).unwrap() + ); } #[test] fn test_div_tensor() { let shape = shape![4].unwrap(); let data1 = vec![4.0, 6.0, 8.0, 10.0]; - let tensor1 = Tensor::new(&shape, &data1).unwrap(); let result = tensor1 / 2.0; - - assert_eq!(result.shape(), &shape); - assert_eq!(result.data, DynamicStorage::new(vec![2.0, 3.0, 4.0, 5.0])); + assert_eq!( + result, + DynamicTensor::new(&shape, &[2.0, 3.0, 4.0, 5.0]).unwrap() + ); } #[test] fn test_vec_vec_mul_single() { let shape = shape![1].unwrap(); - let data1 = vec![2.0]; - let data2 = vec![5.0]; - - let tensor1 = Tensor::new(&shape, &data1).unwrap(); - let tensor2 = Tensor::new(&shape, &data2).unwrap(); + let tensor1 = Tensor::new(&shape, &[2.0]).unwrap(); + let tensor2 = Tensor::new(&shape, &[5.0]).unwrap(); let result = tensor1 * tensor2; - - assert_eq!(result.shape(), &shape![1].unwrap()); - assert_eq!(result.data, DynamicStorage::new(vec![10.0])); + assert_eq!(result, DynamicTensor::new(&shape, &[10.0]).unwrap()); } #[test] @@ -1144,32 +1145,28 @@ mod tests { let shape = shape![4].unwrap(); let data1 = vec![1.0, 2.0, 3.0, 4.0]; let data2 = vec![2.0, 3.0, 4.0, 5.0]; - let tensor1 = Tensor::new(&shape, &data1).unwrap(); let tensor2 = Tensor::new(&shape, &data2).unwrap(); let result = tensor1 * tensor2; - - assert_eq!(result.shape(), &shape); - assert_eq!(result.data, DynamicStorage::new(vec![2.0, 6.0, 12.0, 20.0])); + assert_eq!( + result, + DynamicTensor::new(&shape, &[2.0, 6.0, 12.0, 20.0]).unwrap() + ); } #[test] fn test_matrix_matrix_mul() { - let shape1 = shape![2, 3].unwrap(); - let shape2 = shape![2, 3].unwrap(); + let shape = shape![2, 3].unwrap(); let data1 = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; let data2 = vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0]; - - let tensor1 = Tensor::new(&shape1, &data1).unwrap(); - let tensor2 = Tensor::new(&shape2, &data2).unwrap(); + let tensor1 = Tensor::new(&shape, &data1).unwrap(); + let tensor2 = Tensor::new(&shape, &data2).unwrap(); let result = tensor1 * tensor2; - - assert_eq!(result.shape(), &shape![2, 3].unwrap()); assert_eq!( - result.data, - DynamicStorage::new(vec![7.0, 16.0, 27.0, 40.0, 55.0, 72.0]) + result, + DynamicTensor::new(&shape, &[7.0, 16.0, 27.0, 40.0, 55.0, 72.0]).unwrap() ); } @@ -1180,12 +1177,9 @@ mod tests { let data2 = vec![2.0, 3.0, 4.0, 5.0]; let tensor = DynamicTensor::new(&shape, &data1).unwrap(); let vector = DynamicVector::new(&data2).unwrap(); + let result = tensor + vector; - assert_eq!(result[0], 3.0); - assert_eq!(result[1], 5.0); - assert_eq!(result[2], 7.0); - assert_eq!(result[3], 9.0); - assert_eq!(result.shape(), &shape); + assert_eq!(result, DynamicVector::new(&[3.0, 5.0, 7.0, 9.0]).unwrap()); } #[test] @@ -1195,12 +1189,9 @@ mod tests { let data2 = vec![1.0, 2.0, 3.0, 4.0]; let tensor = DynamicTensor::new(&shape, &data1).unwrap(); let vector = DynamicVector::new(&data2).unwrap(); + let result = tensor - vector; - assert_eq!(result[0], 1.0); - assert_eq!(result[1], 1.0); - assert_eq!(result[2], 1.0); - assert_eq!(result[3], 1.0); - assert_eq!(result.shape(), &shape); + assert_eq!(result, DynamicVector::new(&[1.0; 4]).unwrap()); } #[test] @@ -1210,12 +1201,9 @@ mod tests { let data2 = vec![1.0, 2.0, 3.0, 4.0]; let tensor = DynamicTensor::new(&shape, &data1).unwrap(); let vector = DynamicVector::new(&data2).unwrap(); + let result = tensor * vector; - assert_eq!(result[0], 2.0); - assert_eq!(result[1], 6.0); - assert_eq!(result[2], 12.0); - assert_eq!(result[3], 20.0); - assert_eq!(result.shape(), &shape); + assert_eq!(result, DynamicVector::new(&[2.0, 6.0, 12.0, 20.0]).unwrap()); } #[test] @@ -1225,12 +1213,9 @@ mod tests { let data2 = vec![1.0, 2.0, 3.0, 4.0]; let tensor = DynamicTensor::new(&shape, &data1).unwrap(); let vector = DynamicVector::new(&data2).unwrap(); + let result = tensor / vector; - assert_eq!(result[0], 2.0); - assert_eq!(result[1], 2.0); - assert_eq!(result[2], 2.0); - assert_eq!(result[3], 2.0); - assert_eq!(result.shape(), &shape); + assert_eq!(result, DynamicVector::new(&[2.0; 4]).unwrap()); } #[test] @@ -1240,12 +1225,12 @@ mod tests { let data2 = vec![2.0, 3.0, 4.0, 5.0]; let tensor = DynamicTensor::new(&shape, &data1).unwrap(); let matrix = DynamicMatrix::new(&shape, &data2).unwrap(); + let result = tensor + matrix; - assert_eq!(result[coord![0, 0].unwrap()], 3.0); - assert_eq!(result[coord![0, 1].unwrap()], 5.0); - assert_eq!(result[coord![1, 0].unwrap()], 7.0); - assert_eq!(result[coord![1, 1].unwrap()], 9.0); - assert_eq!(result.shape(), &shape); + assert_eq!( + result, + DynamicMatrix::new(&shape, &[3.0, 5.0, 7.0, 9.0]).unwrap() + ); } #[test] @@ -1255,12 +1240,9 @@ mod tests { let data2 = vec![1.0, 2.0, 3.0, 4.0]; let tensor = DynamicTensor::new(&shape, &data1).unwrap(); let matrix = DynamicMatrix::new(&shape, &data2).unwrap(); + let result = tensor - matrix; - assert_eq!(result[coord![0, 0].unwrap()], 1.0); - assert_eq!(result[coord![0, 1].unwrap()], 1.0); - assert_eq!(result[coord![1, 0].unwrap()], 1.0); - assert_eq!(result[coord![1, 1].unwrap()], 1.0); - assert_eq!(result.shape(), &shape); + assert_eq!(result, DynamicMatrix::new(&shape, &[1.0; 4]).unwrap()); } #[test] @@ -1270,12 +1252,12 @@ mod tests { let data2 = vec![1.0, 2.0, 3.0, 4.0]; let tensor = DynamicTensor::new(&shape, &data1).unwrap(); let matrix = DynamicMatrix::new(&shape, &data2).unwrap(); + let result = tensor * matrix; - assert_eq!(result[coord![0, 0].unwrap()], 2.0); - assert_eq!(result[coord![0, 1].unwrap()], 6.0); - assert_eq!(result[coord![1, 0].unwrap()], 12.0); - assert_eq!(result[coord![1, 1].unwrap()], 20.0); - assert_eq!(result.shape(), &shape); + assert_eq!( + result, + DynamicMatrix::new(&shape, &[2.0, 6.0, 12.0, 20.0]).unwrap() + ); } #[test] @@ -1285,12 +1267,9 @@ mod tests { let data2 = vec![1.0, 2.0, 3.0, 4.0]; let tensor = DynamicTensor::new(&shape, &data1).unwrap(); let matrix = DynamicMatrix::new(&shape, &data2).unwrap(); + let result = tensor / matrix; - assert_eq!(result[coord![0, 0].unwrap()], 2.0); - assert_eq!(result[coord![0, 1].unwrap()], 2.0); - assert_eq!(result[coord![1, 0].unwrap()], 2.0); - assert_eq!(result[coord![1, 1].unwrap()], 2.0); - assert_eq!(result.shape(), &shape); + assert_eq!(result, DynamicMatrix::new(&shape, &[2.0; 4]).unwrap()); } #[test] @@ -1337,8 +1316,10 @@ mod tests { let data = vec![1.0, 2.0, 3.0, 4.0]; let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.pow(2.0); - assert_eq!(result.shape(), &shape); - assert_eq!(result.data, DynamicStorage::new(vec![1.0, 4.0, 9.0, 16.0])); + assert_eq!( + result, + DynamicTensor::new(&shape, &[1.0, 4.0, 9.0, 16.0]).unwrap() + ); } #[test] @@ -1347,8 +1328,10 @@ mod tests { let data = vec![1.0, 4.0, 9.0, 16.0]; let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.pow(0.5); - assert_eq!(result.shape(), &shape); - assert_eq!(result.data, DynamicStorage::new(vec![1.0, 2.0, 3.0, 4.0])); + assert_eq!( + result, + DynamicTensor::new(&shape, &[1.0, 2.0, 3.0, 4.0]).unwrap() + ); } #[test] @@ -1357,10 +1340,9 @@ mod tests { let data = vec![1.0, 2.0, 4.0, 8.0]; let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.pow(-1.0); - assert_eq!(result.shape(), &shape); assert_eq!( - result.data, - DynamicStorage::new(vec![1.0, 0.5, 0.25, 0.125]) + result, + DynamicTensor::new(&shape, &[1.0, 0.5, 0.25, 0.125]).unwrap() ); } } diff --git a/src/vector.rs b/src/vector.rs index ff2b939..200f97c 100644 --- a/src/vector.rs +++ b/src/vector.rs @@ -9,6 +9,7 @@ use crate::tensor::DynamicTensor; use num::Float; use num::Num; +#[derive(Debug, PartialEq)] pub struct DynamicVector { tensor: DynamicTensor, } @@ -251,11 +252,7 @@ mod tests { let data = vec![1.0, 2.0, 3.0, 4.0]; let tensor = DynamicTensor::new(&shape, &data).unwrap(); let vector = DynamicVector::from_tensor(tensor).unwrap(); - assert_eq!(vector.shape(), &shape); - assert_eq!(vector[0], 1.0); - assert_eq!(vector[1], 2.0); - assert_eq!(vector[2], 3.0); - assert_eq!(vector[3], 4.0); + assert_eq!(vector, DynamicVector::new(&data).unwrap()); } #[test] @@ -271,33 +268,21 @@ mod tests { fn test_fill() { let shape = shape![4].unwrap(); let vector = DynamicVector::fill(&shape, 3.0).unwrap(); - assert_eq!(vector.shape(), &shape); - assert_eq!(vector[0], 3.0); - assert_eq!(vector[1], 3.0); - assert_eq!(vector[2], 3.0); - assert_eq!(vector[3], 3.0); + assert_eq!(vector, DynamicVector::new(&[3.0; 4]).unwrap()); } #[test] fn test_zeros() { let shape = shape![4].unwrap(); let vector = DynamicVector::::zeros(&shape).unwrap(); - assert_eq!(vector.shape(), &shape); - assert_eq!(vector[0], 0.0); - assert_eq!(vector[1], 0.0); - assert_eq!(vector[2], 0.0); - assert_eq!(vector[3], 0.0); + assert_eq!(vector, DynamicVector::new(&[0.0; 4]).unwrap()); } #[test] fn test_ones() { let shape = shape![4].unwrap(); let vector = DynamicVector::::ones(&shape).unwrap(); - assert_eq!(vector.shape(), &shape); - assert_eq!(vector[0], 1.0); - assert_eq!(vector[1], 1.0); - assert_eq!(vector[2], 1.0); - assert_eq!(vector[3], 1.0); + assert_eq!(vector, DynamicVector::new(&[1.0; 4]).unwrap()); } #[test] @@ -335,8 +320,7 @@ mod tests { let data = vec![1.0, 2.0, 3.0, 4.0]; let vector = DynamicVector::new(&data).unwrap(); let result = vector.sum(); - assert_eq!(result[0], 10.0); - assert_eq!(result.shape(), &shape![1].unwrap()); + assert_eq!(result, DynamicVector::new(&[10.0]).unwrap()); } #[test] @@ -344,8 +328,7 @@ mod tests { let data = vec![1.0, 2.0, 3.0, 4.0]; let vector = DynamicVector::new(&data).unwrap(); let result = vector.mean(); - assert_eq!(result[0], 2.5); - assert_eq!(result.shape(), &shape![1].unwrap()); + assert_eq!(result, DynamicVector::new(&[2.5]).unwrap()); } #[test] @@ -353,8 +336,7 @@ mod tests { let data = vec![1.0, 2.0, 3.0, 4.0]; let vector = DynamicVector::new(&data).unwrap(); let result = vector.var(); - assert_eq!(result[0], 1.25); - assert_eq!(result.shape(), &shape![1].unwrap()); + assert_eq!(result, DynamicVector::new(&[1.25]).unwrap()); } #[test] @@ -362,8 +344,7 @@ mod tests { let data = vec![1.0, 2.0, 3.0, 4.0]; let vector = DynamicVector::new(&data).unwrap(); let result = vector.min(); - assert_eq!(result[0], 1.0); - assert_eq!(result.shape(), &shape![1].unwrap()); + assert_eq!(result, DynamicVector::new(&[1.0]).unwrap()); } #[test] @@ -371,8 +352,7 @@ mod tests { let data = vec![-1.0, -2.0, -3.0, -4.0]; let vector = DynamicVector::new(&data).unwrap(); let result = vector.max(); - assert_eq!(result[0], -1.0); - assert_eq!(result.shape(), &shape![1].unwrap()); + assert_eq!(result, DynamicVector::new(&[-1.0]).unwrap()); } #[test] @@ -382,8 +362,7 @@ mod tests { let vector1 = DynamicVector::new(&data1).unwrap(); let vector2 = DynamicVector::new(&data2).unwrap(); let result = vector1.vecmul(&vector2); - assert_eq!(result[0], 40.0); - assert_eq!(result.shape(), &shape![1].unwrap()); + assert_eq!(result, DynamicVector::new(&[40.0]).unwrap()); } #[test] @@ -393,9 +372,7 @@ mod tests { let vector = DynamicVector::new(&data_vector).unwrap(); let matrix = DynamicMatrix::new(&shape![2, 2].unwrap(), &data_matrix).unwrap(); let result = vector.matmul(&matrix); - assert_eq!(result.shape(), &shape![2].unwrap()); - assert_eq!(result[0], 7.0); - assert_eq!(result[1], 10.0); + assert_eq!(result, DynamicVector::new(&[7.0, 10.0]).unwrap()); } #[test] @@ -406,20 +383,15 @@ mod tests { let vector2 = DynamicVector::new(&data2).unwrap(); let result = vector1.prod(&vector2); - let expected_data = vec![ - 2.0, 3.0, 4.0, 5.0, 4.0, 6.0, 8.0, 10.0, 6.0, 9.0, 12.0, 15.0, 8.0, 12.0, 16.0, 20.0, - ]; - let expected_shape = shape![4, 4].unwrap(); - let expected_tensor = DynamicTensor::new(&expected_shape, &expected_data).unwrap(); - - assert_eq!(result.shape(), &expected_shape); - for i in 0..result.shape()[0] { - for j in 0..result.shape()[1] { - let x = result.get(&coord![i, j].unwrap()).unwrap(); - let y = expected_tensor.get(&coord![i, j].unwrap()).unwrap(); - assert_eq!(*x, *y); - } - } + let expected_tensor = DynamicTensor::new( + &shape![4, 4].unwrap(), + &[ + 2.0, 3.0, 4.0, 5.0, 4.0, 6.0, 8.0, 10.0, 6.0, 9.0, 12.0, 15.0, 8.0, 12.0, 16.0, + 20.0, + ], + ) + .unwrap(); + assert_eq!(result, expected_tensor); } #[test] @@ -427,26 +399,17 @@ mod tests { let data = vec![1.0, 2.0, 3.0, 4.0]; let vector = DynamicVector::new(&data).unwrap(); let result = vector + 2.0; - assert_eq!(result[0], 3.0); - assert_eq!(result[1], 4.0); - assert_eq!(result[2], 5.0); - assert_eq!(result[3], 6.0); - assert_eq!(result.shape(), &shape![4].unwrap()); + assert_eq!(result, DynamicVector::new(&[3.0, 4.0, 5.0, 6.0]).unwrap()); } #[test] fn test_add_vector() { - let shape = shape![4].unwrap(); let data1 = vec![1.0, 2.0, 3.0, 4.0]; let data2 = vec![2.0, 3.0, 4.0, 5.0]; let vector1 = DynamicVector::new(&data1).unwrap(); let vector2 = DynamicVector::new(&data2).unwrap(); let result = vector1 + vector2; - assert_eq!(result[0], 3.0); - assert_eq!(result[1], 5.0); - assert_eq!(result[2], 7.0); - assert_eq!(result[3], 9.0); - assert_eq!(result.shape(), &shape); + assert_eq!(result, DynamicVector::new(&[3.0, 5.0, 7.0, 9.0]).unwrap()); } #[test] @@ -457,39 +420,25 @@ mod tests { let vector = DynamicVector::new(&data1).unwrap(); let tensor = DynamicTensor::new(&shape, &data2).unwrap(); let result = vector + tensor; - assert_eq!(result[0], 3.0); - assert_eq!(result[1], 5.0); - assert_eq!(result[2], 7.0); - assert_eq!(result[3], 9.0); - assert_eq!(result.shape(), &shape); + assert_eq!(result, DynamicVector::new(&[3.0, 5.0, 7.0, 9.0]).unwrap()); } #[test] fn test_sub_scalar() { - let shape = shape![4].unwrap(); let data = vec![1.0, 2.0, 3.0, 4.0]; let vector = DynamicVector::new(&data).unwrap(); let result = vector - 2.0; - assert_eq!(result[0], -1.0); - assert_eq!(result[1], 0.0); - assert_eq!(result[2], 1.0); - assert_eq!(result[3], 2.0); - assert_eq!(result.shape(), &shape); + assert_eq!(result, DynamicVector::new(&[-1.0, 0.0, 1.0, 2.0]).unwrap()); } #[test] fn test_sub_vector() { - let shape = shape![4].unwrap(); let data1 = vec![1.0, 2.0, 3.0, 4.0]; let data2 = vec![2.0, 3.0, 4.0, 5.0]; let vector1 = DynamicVector::new(&data1).unwrap(); let vector2 = DynamicVector::new(&data2).unwrap(); let result = vector1 - vector2; - assert_eq!(result[0], -1.0); - assert_eq!(result[1], -1.0); - assert_eq!(result[2], -1.0); - assert_eq!(result[3], -1.0); - assert_eq!(result.shape(), &shape); + assert_eq!(result, DynamicVector::new(&[-1.0; 4]).unwrap()); } #[test] @@ -500,39 +449,25 @@ mod tests { let vector = DynamicVector::new(&data1).unwrap(); let tensor = DynamicTensor::new(&shape, &data2).unwrap(); let result = vector - tensor; - assert_eq!(result[0], -1.0); - assert_eq!(result[1], -1.0); - assert_eq!(result[2], -1.0); - assert_eq!(result[3], -1.0); - assert_eq!(result.shape(), &shape); + assert_eq!(result, DynamicVector::new(&[-1.0; 4]).unwrap()); } #[test] fn test_mul_scalar() { - let shape = shape![4].unwrap(); let data = vec![1.0, 2.0, 3.0, 4.0]; let vector = DynamicVector::new(&data).unwrap(); let result = vector * 2.0; - assert_eq!(result[0], 2.0); - assert_eq!(result[1], 4.0); - assert_eq!(result[2], 6.0); - assert_eq!(result[3], 8.0); - assert_eq!(result.shape(), &shape); + assert_eq!(result, DynamicVector::new(&[2.0, 4.0, 6.0, 8.0]).unwrap()); } #[test] fn test_mul_vector() { - let shape = shape![4].unwrap(); let data1 = vec![1.0, 2.0, 3.0, 4.0]; let data2 = vec![2.0, 3.0, 4.0, 5.0]; let vector1 = DynamicVector::new(&data1).unwrap(); let vector2 = DynamicVector::new(&data2).unwrap(); let result = vector1 * vector2; - assert_eq!(result[0], 2.0); - assert_eq!(result[1], 6.0); - assert_eq!(result[2], 12.0); - assert_eq!(result[3], 20.0); - assert_eq!(result.shape(), &shape); + assert_eq!(result, DynamicVector::new(&[2.0, 6.0, 12.0, 20.0]).unwrap()); } #[test] @@ -543,39 +478,25 @@ mod tests { let vector = DynamicVector::new(&data1).unwrap(); let tensor = DynamicTensor::new(&shape, &data2).unwrap(); let result = vector * tensor; - assert_eq!(result[0], 2.0); - assert_eq!(result[1], 6.0); - assert_eq!(result[2], 12.0); - assert_eq!(result[3], 20.0); - assert_eq!(result.shape(), &shape); + assert_eq!(result, DynamicVector::new(&[2.0, 6.0, 12.0, 20.0]).unwrap()); } #[test] fn test_div_scalar() { - let shape = shape![4].unwrap(); let data = vec![4.0, 6.0, 8.0, 10.0]; let vector = DynamicVector::new(&data).unwrap(); let result = vector / 2.0; - assert_eq!(result[0], 2.0); - assert_eq!(result[1], 3.0); - assert_eq!(result[2], 4.0); - assert_eq!(result[3], 5.0); - assert_eq!(result.shape(), &shape); + assert_eq!(result, DynamicVector::new(&[2.0, 3.0, 4.0, 5.0]).unwrap()); } #[test] fn test_div_vector() { - let shape = shape![4].unwrap(); let data1 = vec![4.0, 6.0, 8.0, 10.0]; let data2 = vec![2.0, 3.0, 4.0, 5.0]; let vector1 = DynamicVector::new(&data1).unwrap(); let vector2 = DynamicVector::new(&data2).unwrap(); let result = vector1 / vector2; - assert_eq!(result[0], 2.0); - assert_eq!(result[1], 2.0); - assert_eq!(result[2], 2.0); - assert_eq!(result[3], 2.0); - assert_eq!(result.shape(), &shape); + assert_eq!(result, DynamicVector::new(&[2.0; 4]).unwrap()); } #[test] @@ -586,23 +507,14 @@ mod tests { let vector = DynamicVector::new(&data1).unwrap(); let tensor = DynamicTensor::new(&shape, &data2).unwrap(); let result = vector / tensor; - assert_eq!(result[0], 2.0); - assert_eq!(result[1], 2.0); - assert_eq!(result[2], 2.0); - assert_eq!(result[3], 2.0); - assert_eq!(result.shape(), &shape); + assert_eq!(result, DynamicVector::new(&[2.0; 4]).unwrap()); } #[test] fn test_pow_vector() { - let shape = shape![4].unwrap(); let data = vec![2.0, 3.0, 4.0, 5.0]; let vector = DynamicVector::new(&data).unwrap(); let result = vector.pow(2.0); - assert_eq!(result[0], 4.0); - assert_eq!(result[1], 9.0); - assert_eq!(result[2], 16.0); - assert_eq!(result[3], 25.0); - assert_eq!(result.shape(), &shape); + assert_eq!(result, DynamicVector::new(&[4.0, 9.0, 16.0, 25.0]).unwrap()); } }