diff --git a/src/matrix.rs b/src/matrix.rs
index 215be90..6d9d6c6 100644
--- a/src/matrix.rs
+++ b/src/matrix.rs
@@ -47,6 +47,24 @@ impl DynamicMatrix {
         Self::fill(shape, T::one())
     }
 
+    pub fn redimension(&self, shape: &Shape) -> Result<DynamicMatrix<T>, ShapeError> {
+        if shape.order() != 2 {
+            return Err(ShapeError::new("Shape must have order of 2"));
+        }
+        let result = self.tensor.reshape(shape)?;
+        Ok(DynamicMatrix { tensor: result })
+    }
+
+    pub fn reshape(&self, shape: &Shape) -> Result<Tensor<T>, ShapeError> {
+        self.tensor.reshape(shape)
+    }
+
+    pub fn flatten(&self) -> DynamicVector<T> {
+        let flattened_shape = Shape::new(vec![self.tensor.size()]).unwrap();
+        let result = self.tensor.reshape(&flattened_shape).unwrap();
+        DynamicVector::from_tensor(result).unwrap()
+    }
+
     pub fn sum(&self, axes: Axes) -> DynamicVector<T> {
         let result = self.tensor.sum(axes);
         DynamicVector::from_tensor(result).unwrap()
@@ -328,6 +346,43 @@ mod tests {
         assert_eq!(matrix[coord![1, 1].unwrap()], 1.0);
     }
 
+    #[test]
+    fn test_reshape() {
+        let shape = shape![2, 2].unwrap();
+        let data = vec![1.0, 2.0, 3.0, 4.0];
+        let matrix = DynamicMatrix::new(&shape, &data).unwrap();
+        let new_shape = shape![4, 1].unwrap();
+        let reshaped_matrix = matrix.redimension(&new_shape).unwrap();
+        assert_eq!(reshaped_matrix.shape(), &new_shape);
+        assert_eq!(reshaped_matrix[coord![0, 0].unwrap()], 1.0);
+        assert_eq!(reshaped_matrix[coord![1, 0].unwrap()], 2.0);
+        assert_eq!(reshaped_matrix[coord![2, 0].unwrap()], 3.0);
+        assert_eq!(reshaped_matrix[coord![3, 0].unwrap()], 4.0);
+    }
+
+    #[test]
+    fn test_reshape_fail() {
+        let shape = shape![2, 2].unwrap();
+        let data = vec![1.0, 2.0, 3.0, 4.0];
+        let matrix = DynamicMatrix::new(&shape, &data).unwrap();
+        let new_shape = shape![3, 2].unwrap();
+        let result = matrix.reshape(&new_shape);
+        assert!(result.is_err());
+    }
+
+    #[test]
+    fn test_flatten() {
+        let shape = shape![2, 2].unwrap();
+        let data = vec![1.0, 2.0, 3.0, 4.0];
+        let matrix = DynamicMatrix::new(&shape, &data).unwrap();
+        let flattened_vector = matrix.flatten();
+        assert_eq!(flattened_vector.shape(), &shape![4].unwrap());
+        assert_eq!(flattened_vector[0], 1.0);
+        assert_eq!(flattened_vector[1], 2.0);
+        assert_eq!(flattened_vector[2], 3.0);
+        assert_eq!(flattened_vector[3], 4.0);
+    }
+
     #[test]
     fn test_size() {
         let shape = shape![2, 2].unwrap();
diff --git a/src/storage.rs b/src/storage.rs
index b51fbdf..c5b58f6 100644
--- a/src/storage.rs
+++ b/src/storage.rs
@@ -4,7 +4,7 @@
 use crate::coordinate::Coordinate;
 use crate::error::ShapeError;
 use crate::shape::Shape;
-#[derive(Debug, PartialEq)]
+#[derive(Debug, PartialEq, Clone)]
 pub struct DynamicStorage<T> {
     data: Vec<T>,
 }
diff --git a/src/tensor.rs b/src/tensor.rs
index 595dc26..480c375 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -43,7 +43,21 @@ impl Tensor {
         Tensor::fill(shape, T::one())
     }
 
+    pub fn reshape(&self, shape: &Shape) -> Result<Tensor<T>, ShapeError> {
+        if self.shape.size() != shape.size() {
+            return Err(ShapeError::new("Data length does not match shape size"));
+        }
+        Ok(Tensor {
+            data: self.data.clone(),
+            shape: shape.clone(),
+        })
+    }
+
     // Properties
+    pub fn raw(&self) -> &DynamicStorage<T> {
+        &self.data
+    }
+
     pub fn shape(&self) -> &Shape {
         &self.shape
     }
@@ -51,6 +65,7 @@
         self.shape.size()
     }
 
+    // Access methods
     pub fn get(&self, coord: &Coordinate) -> Result<&T, ShapeError> {
         Ok(&self.data[self.data.flatten(coord, &self.shape)?])
     }
@@ -66,7 +81,7 @@
         Ok(())
     }
 
-    // // Reduction operations
+    // Reduction operations
     pub fn sum(&self, axes: Axes) -> Tensor<T> {
         let all_axes = (0..self.shape.order()).collect::<Vec<usize>>();
         let remaining_axes = all_axes
@@ -594,6 +609,31 @@ mod tests {
         assert_eq!(tensor.data, DynamicStorage::new(vec![1.0; shape.size()]));
     }
 
+    #[test]
+    fn test_reshape_tensor() {
+        let shape = shape![2, 3].unwrap();
+        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
+        let tensor = Tensor::new(&shape, &data).unwrap();
+
+        let new_shape = shape![3, 2].unwrap();
+        let reshaped_tensor = tensor.reshape(&new_shape).unwrap();
+
+        assert_eq!(reshaped_tensor.shape(), &new_shape);
+        assert_eq!(reshaped_tensor.data, DynamicStorage::new(data));
+    }
+
+    #[test]
+    fn test_reshape_tensor_shape_mismatch() {
+        let shape = shape![2, 3].unwrap();
+        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
+        let tensor = Tensor::new(&shape, &data).unwrap();
+
+        let new_shape = shape![4, 2].unwrap();
+        let result = tensor.reshape(&new_shape);
+
+        assert!(result.is_err());
+    }
+
     #[test]
     fn test_fill_tensor() {
         let shape = shape![2, 3].unwrap();
diff --git a/src/vector.rs b/src/vector.rs
index ff2b939..a6c8c69 100644
--- a/src/vector.rs
+++ b/src/vector.rs
@@ -42,6 +42,10 @@ impl DynamicVector {
         Self::fill(shape, T::one())
     }
 
+    pub fn reshape(&self, shape: &Shape) -> Result<Tensor<T>, ShapeError> {
+        self.tensor.reshape(shape)
+    }
+
     pub fn sum(&self) -> DynamicVector<T> {
         let result = self.tensor.sum(vec![]);
         DynamicVector::from_tensor(result).unwrap()