Skip to content

Commit b8cd2d7

Browse files
committed
Arm backend: Handle all types in prepare_input_tensors()
When there are no input values we want to fill it with ones for all scalartypes. Signed-off-by: per.held@arm.com Change-Id: I8fd80eac755305a9cd2d304e5f3f932cf3536557
1 parent d886373 commit b8cd2d7

File tree

1 file changed

+26
-25
lines changed

1 file changed

+26
-25
lines changed

examples/arm/executor_runner/arm_executor_runner.cpp

Lines changed: 26 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@
8787
#include <errno.h>
8888
#include <executorch/extension/data_loader/buffer_data_loader.h>
8989
#include <executorch/extension/runner_util/inputs.h>
90+
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
9091
#include <executorch/runtime/core/memory_allocator.h>
9192
#include <executorch/runtime/executor/program.h>
9293
#include <executorch/runtime/platform/log.h>
@@ -95,6 +96,7 @@
9596
#include <stdio.h>
9697
#include <unistd.h>
9798
#include <memory>
99+
#include <type_traits>
98100
#include <vector>
99101

100102
#include "arm_memory_allocator.h"
@@ -183,6 +185,7 @@ using executorch::runtime::Result;
183185
using executorch::runtime::Span;
184186
using executorch::runtime::Tag;
185187
using executorch::runtime::TensorInfo;
188+
using executorch::runtime::toString;
186189
#if defined(ET_BUNDLE_IO)
187190
using executorch::bundled_program::compute_method_output_error_stats;
188191
using executorch::bundled_program::ErrorStats;
@@ -395,6 +398,19 @@ class Box {
395398
}
396399
};
397400

401+
template <typename ValueType>
402+
void fill_tensor_with_default_value(Tensor& tensor) {
403+
ValueType fill_value{};
404+
if constexpr (std::is_same_v<ValueType, bool>) {
405+
fill_value = true;
406+
} else {
407+
fill_value = ValueType(1);
408+
}
409+
410+
ValueType* data_ptr = tensor.mutable_data_ptr<ValueType>();
411+
std::fill(data_ptr, data_ptr + tensor.numel(), fill_value);
412+
}
413+
398414
Error prepare_input_tensors(
399415
Method& method,
400416
MemoryAllocator& allocator,
@@ -452,32 +468,17 @@ Error prepare_input_tensors(
452468
if (input_evalues[i].isTensor()) {
453469
Tensor& tensor = input_evalues[i].toTensor();
454470
switch (tensor.scalar_type()) {
455-
case ScalarType::Int:
456-
std::fill(
457-
tensor.mutable_data_ptr<int>(),
458-
tensor.mutable_data_ptr<int>() + tensor.numel(),
459-
1);
460-
break;
461-
case ScalarType::Float:
462-
std::fill(
463-
tensor.mutable_data_ptr<float>(),
464-
tensor.mutable_data_ptr<float>() + tensor.numel(),
465-
1.0);
466-
break;
467-
case ScalarType::Char:
468-
std::fill(
469-
tensor.mutable_data_ptr<int8_t>(),
470-
tensor.mutable_data_ptr<int8_t>() + tensor.numel(),
471-
1);
472-
break;
473-
case ScalarType::Bool:
474-
std::fill(
475-
tensor.mutable_data_ptr<int8_t>(),
476-
tensor.mutable_data_ptr<int8_t>() + tensor.numel(),
477-
1);
478-
break;
471+
#define HANDLE_SCALAR_TYPE(cpp_type, scalar_name) \
472+
case ScalarType::scalar_name: \
473+
fill_tensor_with_default_value<cpp_type>(tensor); \
474+
break;
475+
ET_FORALL_SCALAR_TYPES(HANDLE_SCALAR_TYPE)
476+
#undef HANDLE_SCALAR_TYPE
479477
default:
480-
ET_LOG(Error, "Unhandled ScalarType");
478+
ET_LOG(
479+
Error,
480+
"Unhandled ScalarType %s",
481+
toString(tensor.scalar_type()));
481482
err = Error::InvalidArgument;
482483
break;
483484
}

0 commit comments

Comments
 (0)