From b344a69d6c86a06a15cb80b2aae0398e46893383 Mon Sep 17 00:00:00 2001 From: Per Held Date: Wed, 3 Dec 2025 09:51:44 +0100 Subject: [PATCH] Arm backend: Handle all types in prepare_input_tensors() When there are no input values we want to fill it with ones for all scalartypes. Signed-off-by: per.held@arm.com Change-Id: I8fd80eac755305a9cd2d304e5f3f932cf3536557 --- .../executor_runner/arm_executor_runner.cpp | 51 ++++++++++--------- 1 file changed, 26 insertions(+), 25 deletions(-) diff --git a/examples/arm/executor_runner/arm_executor_runner.cpp b/examples/arm/executor_runner/arm_executor_runner.cpp index 87d9026de3f..89ebcd292f7 100644 --- a/examples/arm/executor_runner/arm_executor_runner.cpp +++ b/examples/arm/executor_runner/arm_executor_runner.cpp @@ -87,6 +87,7 @@ #include #include #include +#include #include #include #include @@ -95,6 +96,7 @@ #include #include #include +#include #include #include "arm_memory_allocator.h" @@ -183,6 +185,7 @@ using executorch::runtime::Result; using executorch::runtime::Span; using executorch::runtime::Tag; using executorch::runtime::TensorInfo; +using executorch::runtime::toString; #if defined(ET_BUNDLE_IO) using executorch::bundled_program::compute_method_output_error_stats; using executorch::bundled_program::ErrorStats; @@ -395,6 +398,19 @@ class Box { } }; +template +void fill_tensor_with_default_value(Tensor& tensor) { + ValueType fill_value{}; + if constexpr (std::is_same_v) { + fill_value = true; + } else { + fill_value = ValueType(1); + } + + ValueType* data_ptr = tensor.mutable_data_ptr(); + std::fill(data_ptr, data_ptr + tensor.numel(), fill_value); +} + Error prepare_input_tensors( Method& method, MemoryAllocator& allocator, @@ -452,32 +468,17 @@ Error prepare_input_tensors( if (input_evalues[i].isTensor()) { Tensor& tensor = input_evalues[i].toTensor(); switch (tensor.scalar_type()) { - case ScalarType::Int: - std::fill( - tensor.mutable_data_ptr(), - tensor.mutable_data_ptr() + tensor.numel(), - 1); - break; - case ScalarType::Float: - std::fill( - tensor.mutable_data_ptr(), - tensor.mutable_data_ptr() + tensor.numel(), - 1.0); - break; - case ScalarType::Char: - std::fill( - tensor.mutable_data_ptr(), - tensor.mutable_data_ptr() + tensor.numel(), - 1); - break; - case ScalarType::Bool: - std::fill( - tensor.mutable_data_ptr(), - tensor.mutable_data_ptr() + tensor.numel(), - 1); - break; +#define HANDLE_SCALAR_TYPE(cpp_type, scalar_name) \ + case ScalarType::scalar_name: \ + fill_tensor_with_default_value(tensor); \ + break; + ET_FORALL_SCALAR_TYPES(HANDLE_SCALAR_TYPE) +#undef HANDLE_SCALAR_TYPE default: - ET_LOG(Error, "Unhandled ScalarType"); + ET_LOG( + Error, + "Unhandled ScalarType %s", + toString(tensor.scalar_type())); err = Error::InvalidArgument; break; }