Description
Discussion on PR #197 and elsewhere (e.g., with @youyu3) shows that it's tricky to optimize expressions like transposed(conjugated(scaled(alpha, A))) that result from calling, e.g., matrix_product. "Optimize" here means "deduce that we can call an optimized BLAS routine." For layout_left A, the Fortran BLAS can handle this case directly, by setting TRANSA='C' and ALPHA=alpha.
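As a rough illustration of what this deduction has to produce, the sketch below maps a (transposed, conjugated) pair to the TRANSA/TRANSB character accepted by the reference BLAS xGEMM routines. The name blas_trans_char is hypothetical and not part of any library. It assumes a layout_left matrix and only the three standard options 'N', 'T', and 'C'; conjugation without transposition has no matching option there, so that case would presumably take the fall-back path.

```cpp
// Hypothetical helper (not a library function): map the properties extracted
// from a layout_left matrix expression to the TRANSA/TRANSB character of the
// reference BLAS xGEMM interface. Returns '\0' when no single character
// expresses the combination.
constexpr char blas_trans_char(bool transposed, bool conjugated)
{
  if (transposed) {
    return conjugated ? 'C' : 'T';  // 'C' means conjugate transpose
  }
  return conjugated ? '\0' : 'N';   // conjugate-only has no xGEMM option
}

// transposed(conjugated(scaled(alpha, A))) maps to TRANSA='C';
// the scaling factor is passed separately as ALPHA.
static_assert(blas_trans_char(true, true) == 'C', "conjugate transpose");
```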
It occurred to me that a "recursive" design could make this easier. I put "recursive" in quotes because it's based on function overloads; the calls to the function with the same name aren't actually recursive, because their arguments' types change on each nested call.
Here's some pseudocode:
// N: no transpose, T: transpose, H: conjugate transpose, C: conjugate only
enum class ETrans { N, T, H, C };

template<std::semiregular Scalar, ETrans Trans>
struct Extracted {
  static constexpr ETrans trans = Trans;
  std::optional<Scalar> scalar;
};

template<std::semiregular Scalar, ETrans Trans>
auto toggle_transpose(Extracted<Scalar, Trans> e)
{
  if constexpr (Trans == ETrans::N) {
    return Extracted<Scalar, ETrans::T>{e.scalar};
  } else if constexpr (Trans == ETrans::T) {
    return Extracted<Scalar, ETrans::N>{e.scalar};
  } else if constexpr (Trans == ETrans::H) {
    return Extracted<Scalar, ETrans::C>{e.scalar};
  } else { // ETrans::C
    return Extracted<Scalar, ETrans::H>{e.scalar};
  }
}

template<std::semiregular Scalar, ETrans Trans>
auto toggle_conjugate(Extracted<Scalar, Trans> e)
{
  if constexpr (Trans == ETrans::N) {
    return Extracted<Scalar, ETrans::C>{e.scalar};
  } else if constexpr (Trans == ETrans::T) {
    return Extracted<Scalar, ETrans::H>{e.scalar};
  } else if constexpr (Trans == ETrans::H) {
    return Extracted<Scalar, ETrans::T>{e.scalar};
  } else { // ETrans::C
    return Extracted<Scalar, ETrans::N>{e.scalar};
  }
}

template<std::semiregular InputScalar, std::semiregular Scalar, ETrans Trans>
auto add_or_replace_scalar(InputScalar s, Extracted<Scalar, Trans> e)
{
  // Keep the transpose state (Trans) but discard the current scalar in e.
  return Extracted<InputScalar, Trans>{s};
}

// omitting constraints on template parameters for brevity
template<class in_matrix_1_t, class Extracted1,
         class in_matrix_2_t, class Extracted2,
         class out_matrix_t,
         class in_matrix_1_original_t,
         class in_matrix_2_original_t>
void matrix_product_impl(
  in_matrix_1_t A, Extracted1 A_data,
  in_matrix_2_t B, Extracted2 B_data,
  out_matrix_t C,
  in_matrix_1_original_t A_original,
  in_matrix_2_original_t B_original)
{
  if constexpr (/* It's obvious we can't call the BLAS */) {
    // Early exit from the "recursion" avoids penalizing the generic case with higher compile times.
    matrix_product_fallback(A_original, B_original, C);
  }
  else if constexpr (/* A's outer layout is layout_transpose */) {
    matrix_product_impl(strip_nested_mapping(A), toggle_transpose(A_data),
                        B, B_data, C, A_original, B_original);
  }
  else if constexpr (/* A's outer accessor is accessor_scaled */) {
    if (A_data.scalar.has_value()) {
      // ... check at compile time that it makes sense to multiply the two scaling factors, else fall back ...
      matrix_product_impl(strip_nested_accessor(A),
                          add_or_replace_scalar(A.accessor().scaling_factor() * A_data.scalar.value(), A_data),
                          B, B_data, C, A_original, B_original);
    } else {
      matrix_product_impl(strip_nested_accessor(A),
                          add_or_replace_scalar(A.accessor().scaling_factor(), A_data),
                          B, B_data, C, A_original, B_original);
    }
  }
  else if constexpr (/* A's outer accessor is accessor_conjugate */) {
    matrix_product_impl(strip_nested_accessor(A), toggle_conjugate(A_data),
                        B, B_data, C, A_original, B_original);
  }
  // ... repeat the above pattern for B ...
  else if constexpr (/* all the types are BLAS friendly */) {
    // ETrans is a template parameter of Extracted so we can check most BLAS compatibility at compile time.
    // Extracted1 and Extracted2 are template parameters so that we don't force Scalar type conversion.
    // Some mixed-precision BLAS implementations (e.g., cuGemmEx) permit the Scalar type
    // to have a different type than the matrices' value types.
    if (/* any run-time decision whether we can call BLAS, e.g., layout_stride run-time strides */) {
      // ... call the BLAS using scalar and transpose from both Extracted structs ...
    } else {
      matrix_product_fallback(A_original, B_original, C);
    }
  }
  else {
    matrix_product_fallback(A_original, B_original, C);
  }
}

template<class in_matrix_1_t, class in_matrix_2_t, class out_matrix_t>
void matrix_product(in_matrix_1_t A, in_matrix_2_t B, out_matrix_t C)
{
  // "Recursive" calls may change the Extracted Scalar type.
  matrix_product_impl(A, Extracted<typename in_matrix_1_t::value_type, ETrans::N>{},
                      B, Extracted<typename in_matrix_2_t::value_type, ETrans::N>{},
                      C, A, B);
}

For a C++14-compatible implementation, one could use function overloads (partial specialization) instead of if constexpr.
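A minimal sketch of that overload-based variant, restricted to the transpose state, might look as follows. TransTag is a hypothetical tag type introduced only for this illustration; the scalar member is left out because std::optional requires C++17. Each overload stands in for one branch of the if constexpr chains in toggle_transpose and toggle_conjugate.

```cpp
#include <type_traits>

enum class ETrans { N, T, H, C };

// Hypothetical tag type carrying the transpose state in its type (C++14).
template<ETrans Trans>
struct TransTag { static constexpr ETrans value = Trans; };

// One overload per state replaces each branch of the if constexpr chain.
constexpr TransTag<ETrans::T> toggle_transpose(TransTag<ETrans::N>) { return {}; }
constexpr TransTag<ETrans::N> toggle_transpose(TransTag<ETrans::T>) { return {}; }
constexpr TransTag<ETrans::C> toggle_transpose(TransTag<ETrans::H>) { return {}; }
constexpr TransTag<ETrans::H> toggle_transpose(TransTag<ETrans::C>) { return {}; }

constexpr TransTag<ETrans::C> toggle_conjugate(TransTag<ETrans::N>) { return {}; }
constexpr TransTag<ETrans::H> toggle_conjugate(TransTag<ETrans::T>) { return {}; }
constexpr TransTag<ETrans::T> toggle_conjugate(TransTag<ETrans::H>) { return {}; }
constexpr TransTag<ETrans::N> toggle_conjugate(TransTag<ETrans::C>) { return {}; }

// Conjugating and then transposing an untouched matrix yields H
// (conjugate transpose), which corresponds to TRANSA='C' in the BLAS.
static_assert(std::is_same<
                decltype(toggle_transpose(toggle_conjugate(TransTag<ETrans::N>{}))),
                TransTag<ETrans::H>>::value,
              "expected conjugate transpose");
```

Overload resolution picks the matching state at compile time, so the resulting type still encodes the transpose state just as Extracted does.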
Here are some issues with the above approach.
1. It has to construct an mdspan once per "recursion" level.
2. It increases the function call depth, which may interfere with inlining in the fall-back case.
We can fix at least (2) by applying the "recursive" approach separately to each pair, (A, A_data) and (B, B_data). This will bound the function call depth for the fall-back case.
template<class InMatrix, class ExtractedType>
auto extract(std::tuple<InMatrix, ExtractedType>);

We can mitigate (1) by limiting the "recursion" depth for cases that the BLAS obviously can't handle. Also, taking the mdspan by value lets us move-construct the pointer, layout, and accessor at each level, so we can reduce cost for the (admittedly unusual) case where any of these are expensive to construct.
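The per-matrix extract declared above could be sketched roughly as follows, in C++17. This is only an illustration under several assumptions: PlainMatrix, TransposedView, ConjugatedView, and ScaledView are hypothetical stand-ins for an mdspan with nested layouts and accessors, the scalar_type alias and the toggled_* helpers are additions for this sketch, and the scaling factor is a plain double, so the sketch does not cover how a complex factor nested inside a conjugation would need to be handled.

```cpp
#include <optional>
#include <tuple>
#include <type_traits>

enum class ETrans { N, T, H, C };

template<class Scalar, ETrans Trans>
struct Extracted {
  using scalar_type = Scalar;  // assumption: alias added for this sketch
  static constexpr ETrans trans = Trans;
  std::optional<Scalar> scalar;
};

constexpr ETrans toggled_transpose(ETrans t) {
  switch (t) {
    case ETrans::N: return ETrans::T;
    case ETrans::T: return ETrans::N;
    case ETrans::H: return ETrans::C;
    default:        return ETrans::H;  // ETrans::C
  }
}
constexpr ETrans toggled_conjugate(ETrans t) {
  switch (t) {
    case ETrans::N: return ETrans::C;
    case ETrans::T: return ETrans::H;
    case ETrans::H: return ETrans::T;
    default:        return ETrans::N;  // ETrans::C
  }
}

// Hypothetical stand-ins for an mdspan and its nested layouts/accessors.
struct PlainMatrix {};
template<class Nested> struct TransposedView { Nested nested; };
template<class Nested> struct ConjugatedView { Nested nested; };
template<class Nested> struct ScaledView { double alpha; Nested nested; };

template<class T> struct is_transposed : std::false_type {};
template<class N> struct is_transposed<TransposedView<N>> : std::true_type {};
template<class T> struct is_conjugated : std::false_type {};
template<class N> struct is_conjugated<ConjugatedView<N>> : std::true_type {};
template<class T> struct is_scaled : std::false_type {};
template<class N> struct is_scaled<ScaledView<N>> : std::true_type {};

// Strips one wrapper per "recursive" call; each call has different argument
// types, and the depth is bounded by the nesting of a single matrix.
template<class InMatrix, class ExtractedType>
constexpr auto extract(std::tuple<InMatrix, ExtractedType> p)
{
  using S = typename ExtractedType::scalar_type;
  constexpr ETrans trans = ExtractedType::trans;
  const InMatrix& m = std::get<0>(p);
  const ExtractedType& e = std::get<1>(p);
  if constexpr (is_transposed<InMatrix>::value) {
    return extract(std::make_tuple(
      m.nested, Extracted<S, toggled_transpose(trans)>{e.scalar}));
  } else if constexpr (is_conjugated<InMatrix>::value) {
    return extract(std::make_tuple(
      m.nested, Extracted<S, toggled_conjugate(trans)>{e.scalar}));
  } else if constexpr (is_scaled<InMatrix>::value) {
    // Multiply with a previously extracted factor, if there is one.
    const S alpha = e.scalar.has_value() ? m.alpha * e.scalar.value() : m.alpha;
    return extract(std::make_tuple(m.nested, Extracted<S, trans>{alpha}));
  } else {
    return p;  // nothing left to strip
  }
}

// transposed(conjugated(scaled(2.0, A))) -> (A, trans = H, scalar = 2.0)
using Wrapped = TransposedView<ConjugatedView<ScaledView<PlainMatrix>>>;
constexpr Wrapped wrapped_A{{{2.0, PlainMatrix{}}}};
constexpr auto stripped_A =
  extract(std::make_tuple(wrapped_A, Extracted<double, ETrans::N>{}));
static_assert(std::get<1>(stripped_A).trans == ETrans::H);
static_assert(std::get<1>(stripped_A).scalar.value() == 2.0);
```

With this shape, matrix_product would presumably call extract once for A and once for B and then make a single dispatch decision, rather than threading both matrices through every nested call.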