|
87 | 87 | #include <errno.h> |
88 | 88 | #include <executorch/extension/data_loader/buffer_data_loader.h> |
89 | 89 | #include <executorch/extension/runner_util/inputs.h> |
| 90 | +#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h> |
90 | 91 | #include <executorch/runtime/core/memory_allocator.h> |
91 | 92 | #include <executorch/runtime/executor/program.h> |
92 | 93 | #include <executorch/runtime/platform/log.h> |
|
95 | 96 | #include <stdio.h> |
96 | 97 | #include <unistd.h> |
97 | 98 | #include <memory> |
| 99 | +#include <type_traits> |
98 | 100 | #include <vector> |
99 | 101 |
|
100 | 102 | #include "arm_memory_allocator.h" |
@@ -183,6 +185,7 @@ using executorch::runtime::Result; |
183 | 185 | using executorch::runtime::Span; |
184 | 186 | using executorch::runtime::Tag; |
185 | 187 | using executorch::runtime::TensorInfo; |
| 188 | +using executorch::runtime::toString; |
186 | 189 | #if defined(ET_BUNDLE_IO) |
187 | 190 | using executorch::bundled_program::compute_method_output_error_stats; |
188 | 191 | using executorch::bundled_program::ErrorStats; |
@@ -395,6 +398,19 @@ class Box { |
395 | 398 | } |
396 | 399 | }; |
397 | 400 |
|
| 401 | +template <typename ValueType> |
| 402 | +void fill_tensor_with_default_value(Tensor& tensor) { |
| 403 | + ValueType fill_value{}; |
| 404 | + if constexpr (std::is_same_v<ValueType, bool>) { |
| 405 | + fill_value = true; |
| 406 | + } else { |
| 407 | + fill_value = ValueType(1); |
| 408 | + } |
| 409 | + |
| 410 | + ValueType* data_ptr = tensor.mutable_data_ptr<ValueType>(); |
| 411 | + std::fill(data_ptr, data_ptr + tensor.numel(), fill_value); |
| 412 | +} |
| 413 | + |
398 | 414 | Error prepare_input_tensors( |
399 | 415 | Method& method, |
400 | 416 | MemoryAllocator& allocator, |
@@ -452,32 +468,17 @@ Error prepare_input_tensors( |
452 | 468 | if (input_evalues[i].isTensor()) { |
453 | 469 | Tensor& tensor = input_evalues[i].toTensor(); |
454 | 470 | switch (tensor.scalar_type()) { |
455 | | - case ScalarType::Int: |
456 | | - std::fill( |
457 | | - tensor.mutable_data_ptr<int>(), |
458 | | - tensor.mutable_data_ptr<int>() + tensor.numel(), |
459 | | - 1); |
460 | | - break; |
461 | | - case ScalarType::Float: |
462 | | - std::fill( |
463 | | - tensor.mutable_data_ptr<float>(), |
464 | | - tensor.mutable_data_ptr<float>() + tensor.numel(), |
465 | | - 1.0); |
466 | | - break; |
467 | | - case ScalarType::Char: |
468 | | - std::fill( |
469 | | - tensor.mutable_data_ptr<int8_t>(), |
470 | | - tensor.mutable_data_ptr<int8_t>() + tensor.numel(), |
471 | | - 1); |
472 | | - break; |
473 | | - case ScalarType::Bool: |
474 | | - std::fill( |
475 | | - tensor.mutable_data_ptr<int8_t>(), |
476 | | - tensor.mutable_data_ptr<int8_t>() + tensor.numel(), |
477 | | - 1); |
478 | | - break; |
| 471 | +#define HANDLE_SCALAR_TYPE(cpp_type, scalar_name) \ |
| 472 | + case ScalarType::scalar_name: \ |
| 473 | + fill_tensor_with_default_value<cpp_type>(tensor); \ |
| 474 | + break; |
| 475 | + ET_FORALL_SCALAR_TYPES(HANDLE_SCALAR_TYPE) |
| 476 | +#undef HANDLE_SCALAR_TYPE |
479 | 477 | default: |
480 | | - ET_LOG(Error, "Unhandled ScalarType"); |
| 478 | + ET_LOG( |
| 479 | + Error, |
| 480 | + "Unhandled ScalarType %s", |
| 481 | + toString(tensor.scalar_type())); |
481 | 482 | err = Error::InvalidArgument; |
482 | 483 | break; |
483 | 484 | } |
|
0 commit comments