diff --git a/include/gridtools/fn/backend/common.hpp b/include/gridtools/fn/backend/common.hpp index 2626eb718..6616ee9ee 100644 --- a/include/gridtools/fn/backend/common.hpp +++ b/include/gridtools/fn/backend/common.hpp @@ -31,10 +31,30 @@ namespace gridtools::fn::backend { meta::rename()); } + template + constexpr GT_FUNCTION auto make_unrolled_loops(Sizes const &sizes, UnrollFactors) { + return tuple_util::host_device::fold( + [&](auto outer, auto dim) { + using unroll_factor = std::remove_reference_t( + std::declval()))>; + return [outer = std::move(outer), + inner = sid::make_unrolled_loop( + host_device::at_key(sizes))]( + auto &&...args) { return outer(inner(std::forward(args)...)); }; + }, + host_device::identity(), + meta::rename()); + } + template constexpr GT_FUNCTION auto make_loops(Sizes const &sizes) { return make_loops>(sizes); } + + template + constexpr GT_FUNCTION auto make_unrolled_loops(Sizes const &sizes, UnrollFactors unroll_factors) { + return make_unrolled_loops>(sizes, unroll_factors); + } } // namespace common template diff --git a/include/gridtools/fn/backend/gpu.hpp b/include/gridtools/fn/backend/gpu.hpp index 5f1500b70..c771941ea 100644 --- a/include/gridtools/fn/backend/gpu.hpp +++ b/include/gridtools/fn/backend/gpu.hpp @@ -9,6 +9,7 @@ */ #pragma once +#include #include #include "../../common/cuda_util.hpp" @@ -23,108 +24,154 @@ namespace gridtools::fn::backend { namespace gpu_impl_ { + template + struct is_valid_block_size_key_value_pair : std::false_type {}; + + template