
Optimizing nested conjugated / transposed / scaled expressions #203

@mhoemmen

Discussion on PR #197 and elsewhere (e.g., with @youyu3) shows that it's tricky to optimize expressions like transposed(conjugated(scaled(alpha, A))) that arise in calls to, e.g., matrix_product. "Optimize" here means "deduce that we can call an optimized BLAS routine." For layout_left A, the Fortran BLAS can handle this case directly, by setting TRANSA='C' and ALPHA to the complex conjugate of alpha (which is just alpha when alpha is real).
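Working the nested expression out elementwise (a bar denotes complex conjugation and A^H is the conjugate transpose) shows which GEMM arguments it corresponds to:

```latex
\bigl[\operatorname{transposed}(\operatorname{conjugated}(\operatorname{scaled}(\alpha, A)))\bigr]_{ij}
  = \overline{\alpha\, A_{ji}}
  = \bar{\alpha}\,\overline{A_{ji}}
  = \bar{\alpha}\,(A^{H})_{ij}
```

So op(A) = A^H, which is what TRANSA='C' selects, and the scaling factor that reaches GEMM is \bar{\alpha}, which equals \alpha only when alpha is real.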

It occurred to me that a "recursive" design could make this easier. I put "recursive" in quotes because it's based on function overloads; the calls to the function with the same name aren't actually recursive, because their arguments' types change on each nested call.
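To make the "recursion by overloads" idea concrete, here is a minimal, self-contained sketch. The wrapper types Base, Transposed, Conjugated, and Scaled and the function nesting_depth are hypothetical stand-ins, not the real mdspan or linalg types; the sketch only counts how many wrappers get peeled.

```cpp
// Hypothetical wrapper types standing in for the nested views that
// transposed(), conjugated(), and scaled() produce.
struct Base {};
template<class Inner> struct Transposed { Inner inner; };
template<class Inner> struct Conjugated { Inner inner; };
template<class Inner> struct Scaled { double alpha; Inner inner; };

// Declare every overload first so that each body can see all of them.
constexpr int nesting_depth(const Base&);
template<class Inner> constexpr int nesting_depth(const Transposed<Inner>&);
template<class Inner> constexpr int nesting_depth(const Conjugated<Inner>&);
template<class Inner> constexpr int nesting_depth(const Scaled<Inner>&);

// Terminal overload: nothing left to peel.
constexpr int nesting_depth(const Base&) { return 0; }

// Each overload peels one wrapper. The nested call has a different argument
// type, so it resolves to a different overload or template specialization;
// no function ever calls itself.
template<class Inner>
constexpr int nesting_depth(const Transposed<Inner>& e) { return 1 + nesting_depth(e.inner); }

template<class Inner>
constexpr int nesting_depth(const Conjugated<Inner>& e) { return 1 + nesting_depth(e.inner); }

template<class Inner>
constexpr int nesting_depth(const Scaled<Inner>& e) { return 1 + nesting_depth(e.inner); }
```

Each call in the chain instantiates a distinct specialization, so the compiler sees a finite sequence of different functions. For example, `nesting_depth(Transposed<Conjugated<Scaled<Base>>>{})` evaluates to 3 at compile time.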

Here's some pseudocode:

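#include <concepts>  // std::semiregular
#include <optional>  // std::optional

// N: no transpose; T: transpose; H: conjugate transpose (Fortran BLAS TRANS='C');
// C: conjugate without transpose (the reference BLAS GEMM has no TRANSA option for this)
enum class ETrans { N, T, H, C };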

template<std::semiregular Scalar, ETrans Trans>
struct Extracted {
    static constexpr ETrans trans = Trans;
    std::optional<Scalar> scalar;
};

template<std::semiregular Scalar, ETrans Trans>
auto toggle_transpose(Extracted<Scalar, Trans> e)
{
    if constexpr (Trans == ETrans::N) {
        return Extracted<Scalar, ETrans::T>{e.scalar};
    } else if constexpr (Trans == ETrans::T) {
        return Extracted<Scalar, ETrans::N>{e.scalar};
    } else if constexpr (Trans == ETrans::H) {
        return Extracted<Scalar, ETrans::C>{e.scalar};
    } else { // ETrans::C
        return Extracted<Scalar, ETrans::H>{e.scalar};
    }
}

template<std::semiregular Scalar, ETrans Trans>
auto toggle_conjugate(Extracted<Scalar, Trans> e)
{
    if constexpr (Trans == ETrans::N) {
        return Extracted<Scalar, ETrans::C>{e.scalar};
    } else if constexpr (Trans == ETrans::T) {
        return Extracted<Scalar, ETrans::H>{e.scalar};
    } else if constexpr (Trans == ETrans::H) {
        return Extracted<Scalar, ETrans::T>{e.scalar};
    } else { // ETrans::C
        return Extracted<Scalar, ETrans::N>{e.scalar};
    }
}

template<std::semiregular InputScalar, std::semiregular Scalar, ETrans Trans>
auto add_or_replace_scalar(InputScalar s, Extracted<Scalar, Trans> e)
{
    return Extracted<InputScalar, Trans>{s}; // keep e's transpose flag; discard e's current scalar
}

// omitting constraints on template parameters for brevity
template<class in_matrix_1_t, class Extracted1,
    class in_matrix_2_t, class Extracted2,
    class out_matrix_t,
    class in_matrix_1_original_t,
    class in_matrix_2_original_t>
void matrix_product_impl(
    in_matrix_1_t A, Extracted1 A_data,
    in_matrix_2_t B, Extracted2 B_data,
    out_matrix_t C,
    in_matrix_1_original_t A_original,
    in_matrix_2_original_t B_original)
{
    if constexpr (/* It's obvious we can't call the BLAS */) {
        // Early exit from the "recursion" avoids penalizing the generic case with higher compile times.
        matrix_product_fallback(A_original, B_original, C);
    }
    else if constexpr (/* A's outer layout is layout_transpose */) {
        matrix_product_impl(strip_nested_mapping(A), toggle_transpose(A_data),
            B, B_data, C, A_original, B_original);
    }
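    else if constexpr (/* A's outer accessor is accessor_scaled */) {
        // If A_data.trans is ETrans::C or ETrans::H, an enclosing conjugated()
        // also applies to this scaling factor, so use its complex conjugate
        // before combining (a no-op for real scaling factors).
        if(A_data.scalar.has_value()) {
            // ... check at compile time that it makes sense to multiply the two scaling factors, else fall back ...
            matrix_product_impl(strip_nested_accessor(A),
                add_or_replace_scalar(A.accessor().scaling_factor() * A_data.scalar.value(), A_data),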
                B, B_data, C, A_original, B_original);
        } else {
            matrix_product_impl(strip_nested_accessor(A), add_or_replace_scalar(A.accessor().scaling_factor(), A_data),
                B, B_data, C, A_original, B_original);
        }
    }
    else if constexpr (/* A's outer accessor is accessor_conjugate */) {
        matrix_product_impl(strip_nested_accessor(A), toggle_conjugate(A_data),
            B, B_data, C, A_original, B_original);
    }
    // ... repeat the above pattern for B ...
    else if constexpr (/* all the types are BLAS friendly */) {
        // ETrans is a template parameter of Extracted so we can check most BLAS compatibility at compile time.
        // Extracted1 and Extracted2 are template parameters so that we don't force Scalar type conversion.
        // Some mixed-precision BLAS implementations (e.g., cuBLAS's cublasGemmEx) permit the Scalar type
        // to differ from the matrices' value types.

        if (/* any run-time decision whether we can call BLAS, e.g., layout_stride run-time strides */) {
            // ... call the BLAS using scalar and transpose from both Extracted structs ...
        } else {
            matrix_product_fallback(A_original, B_original, C);
        }
    }
    else {
        matrix_product_fallback(A_original, B_original, C);
    }
}

template<class in_matrix_1_t, class in_matrix_2_t, class out_matrix_t>
void matrix_product(in_matrix_1_t A, in_matrix_2_t B, out_matrix_t C)
{
    // "Recursive" calls may change the Extracted Scalar type.
    matrix_product_impl(A, Extracted<typename in_matrix_1_t::value_type, ETrans::N>{},
        B, Extracted<typename in_matrix_2_t::value_type, ETrans::N>{},
        C, A, B);
}
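The pseudocode above can't compile as written (its conditions are placeholders), so here is a runnable C++17 toy that models only its bookkeeping: peeling transposed/conjugated/scaled wrappers into an ETrans flag plus an optional scaling factor. The wrapper types and the functions transpose_of, conjugate_of, conj_scalar, and extract are hypothetical; there is no mdspan and no BLAS call. When a scaled() wrapper sits inside a conjugated() wrapper, the toy conjugates that scaling factor, since conj(alpha * a) = conj(alpha) * conj(a).

```cpp
#include <complex>
#include <optional>
#include <utility>

// Toy model (hypothetical names): wrappers stand in for mdspan views built by
// transposed(), conjugated(), and scaled(); only the bookkeeping is modeled.
enum class ETrans { N, T, H, C };  // H: conjugate transpose; C: conjugate only

template<class Scalar, ETrans Trans>
struct Extracted {
    static constexpr ETrans trans = Trans;
    std::optional<Scalar> scalar;
};

struct Base {};
template<class Inner> struct Transposed { Inner inner; };
template<class Inner> struct Conjugated { Inner inner; };
template<class S, class Inner> struct Scaled { S alpha; Inner inner; };

// Same tables as toggle_transpose / toggle_conjugate, as constexpr functions.
constexpr ETrans transpose_of(ETrans t) {
    return t == ETrans::N ? ETrans::T : t == ETrans::T ? ETrans::N
         : t == ETrans::H ? ETrans::C : ETrans::H;
}
constexpr ETrans conjugate_of(ETrans t) {
    return t == ETrans::N ? ETrans::C : t == ETrans::C ? ETrans::N
         : t == ETrans::T ? ETrans::H : ETrans::T;
}

// Complex conjugate of a scaling factor (the identity for real types).
constexpr double conj_scalar(double x) { return x; }
inline std::complex<double> conj_scalar(std::complex<double> z) { return std::conj(z); }

// Terminal overload: nothing left to peel.
template<class S, ETrans Tr>
constexpr Extracted<S, Tr> extract(Base, Extracted<S, Tr> e) { return e; }

template<class Inner, class S, ETrans Tr>
constexpr auto extract(Transposed<Inner> x, Extracted<S, Tr> e) {
    return extract(x.inner, Extracted<S, transpose_of(Tr)>{e.scalar});
}

template<class Inner, class S, ETrans Tr>
constexpr auto extract(Conjugated<Inner> x, Extracted<S, Tr> e) {
    return extract(x.inner, Extracted<S, conjugate_of(Tr)>{e.scalar});
}

template<class S2, class Inner, class S, ETrans Tr>
constexpr auto extract(Scaled<S2, Inner> x, Extracted<S, Tr> e) {
    // An enclosing conjugated() (flag C or H) also conjugates this factor.
    S2 alpha = x.alpha;
    if constexpr (Tr == ETrans::C || Tr == ETrans::H) alpha = conj_scalar(alpha);
    using R = decltype(alpha * std::declval<S>());
    if (e.scalar.has_value())
        return extract(x.inner, Extracted<R, Tr>{R(alpha * *e.scalar)});
    return extract(x.inner, Extracted<R, Tr>{R(alpha)});
}
```

For instance, with a complex alpha, extracting `Transposed<Conjugated<Scaled<std::complex<double>, Base>>>` starting from `Extracted<double, ETrans::N>{}` gives `trans == ETrans::H` and `scalar == std::conj(alpha)`, matching TRANSA='C' with ALPHA=conj(alpha).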

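For a C++14-compatible implementation, one could replace if constexpr with function overloads (e.g., tag dispatch) or with partial specialization of a helper class template; function templates themselves cannot be partially specialized. (std::optional and std::semiregular would also need C++14 substitutes.)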
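As one possible C++14 shape (a sketch, not the only option), toggle_transpose and toggle_conjugate from the pseudocode can become overload sets: for each argument type exactly one overload is viable, so overload resolution does the branching that if constexpr did. Because C++14 has no std::optional, this hypothetical Extracted pairs the scalar with a has_scalar flag.

```cpp
enum class ETrans { N, T, H, C };

// C++14 stand-in for Extracted: no std::optional, so carry an explicit flag.
template<class Scalar, ETrans Trans>
struct Extracted {
    static constexpr ETrans trans = Trans;
    Scalar scalar;
    bool has_scalar;
};

// One overload per input flag; only the overload whose parameter type
// matches the argument's ETrans value is viable.
template<class S>
constexpr Extracted<S, ETrans::T> toggle_transpose(Extracted<S, ETrans::N> e) { return {e.scalar, e.has_scalar}; }
template<class S>
constexpr Extracted<S, ETrans::N> toggle_transpose(Extracted<S, ETrans::T> e) { return {e.scalar, e.has_scalar}; }
template<class S>
constexpr Extracted<S, ETrans::C> toggle_transpose(Extracted<S, ETrans::H> e) { return {e.scalar, e.has_scalar}; }
template<class S>
constexpr Extracted<S, ETrans::H> toggle_transpose(Extracted<S, ETrans::C> e) { return {e.scalar, e.has_scalar}; }

template<class S>
constexpr Extracted<S, ETrans::C> toggle_conjugate(Extracted<S, ETrans::N> e) { return {e.scalar, e.has_scalar}; }
template<class S>
constexpr Extracted<S, ETrans::H> toggle_conjugate(Extracted<S, ETrans::T> e) { return {e.scalar, e.has_scalar}; }
template<class S>
constexpr Extracted<S, ETrans::T> toggle_conjugate(Extracted<S, ETrans::H> e) { return {e.scalar, e.has_scalar}; }
template<class S>
constexpr Extracted<S, ETrans::N> toggle_conjugate(Extracted<S, ETrans::C> e) { return {e.scalar, e.has_scalar}; }
```

The same pattern extends to matrix_product_impl itself: one overload per wrapper type (layout_transpose, accessor_conjugate, accessor_scaled) replaces each branch of the if constexpr chain.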

Here are some issues with the above approach.

  1. It has to construct a new mdspan at each "recursion" level.
  2. It increases the function call depth, which may interfere with inlining in the fallback case.

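We can fix at least (2) by applying the recursive approach separately to each pair (A, A_data) and (B, B_data), and then dispatching once on the results. This bounds the function call depth for the fallback case, because the fallback is then called directly from matrix_product instead of from the bottom of a chain of nested calls.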

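#include <tuple>

// Presumed contract: peel every wrapper off one operand and return the
// innermost mdspan together with the accumulated Extracted.
template<class InMatrix, class ExtractedType>
auto extract(std::tuple<InMatrix, ExtractedType>);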
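Here is a minimal C++14 sketch of the per-operand idea, using toy types (Base, Transposed, Conjugated, Flags, Path, and matrix_product_path are all hypothetical, and scaling is left out for brevity). Each operand is peeled by its own chain of extract calls; afterwards a single, non-recursive decision picks the BLAS or the fallback, so that decision always happens at the same call depth.

```cpp
#include <tuple>

// Toy model (hypothetical names): Flags<Trans> plays the role of Extracted.
enum class ETrans { N, T, H, C };
template<ETrans Trans> struct Flags { static constexpr ETrans trans = Trans; };

struct Base {};
template<class Inner> struct Transposed { Inner inner; };
template<class Inner> struct Conjugated { Inner inner; };

constexpr ETrans transpose_of(ETrans t) {
    return t == ETrans::N ? ETrans::T : t == ETrans::T ? ETrans::N
         : t == ETrans::H ? ETrans::C : ETrans::H;
}
constexpr ETrans conjugate_of(ETrans t) {
    return t == ETrans::N ? ETrans::C : t == ETrans::C ? ETrans::N
         : t == ETrans::T ? ETrans::H : ETrans::T;
}

// extract peels one operand completely and returns (innermost view, flags).
template<ETrans Tr>
constexpr std::tuple<Base, Flags<Tr>> extract(std::tuple<Base, Flags<Tr>> t) { return t; }

template<class Inner, ETrans Tr>
constexpr auto extract(std::tuple<Transposed<Inner>, Flags<Tr>> t) {
    return extract(std::tuple<Inner, Flags<transpose_of(Tr)>>{std::get<0>(t).inner, Flags<transpose_of(Tr)>{}});
}

template<class Inner, ETrans Tr>
constexpr auto extract(std::tuple<Conjugated<Inner>, Flags<Tr>> t) {
    return extract(std::tuple<Inner, Flags<conjugate_of(Tr)>>{std::get<0>(t).inner, Flags<conjugate_of(Tr)>{}});
}

enum class Path { Blas, Fallback };

// Each operand is peeled independently, then one dispatch decides. However
// deeply A and B were nested, the decision is made here, at a fixed depth.
template<class InA, class InB>
constexpr Path matrix_product_path(InA a, InB b) {
    auto ea = extract(std::tuple<InA, Flags<ETrans::N>>{a, Flags<ETrans::N>{}});
    auto eb = extract(std::tuple<InB, Flags<ETrans::N>>{b, Flags<ETrans::N>{}});
    using FA = std::tuple_element_t<1, decltype(ea)>;
    using FB = std::tuple_element_t<1, decltype(eb)>;
    // The reference BLAS GEMM has no TRANSA/TRANSB option for "conjugate, no transpose".
    constexpr bool blas_ok = FA::trans != ETrans::C && FB::trans != ETrans::C;
    // A real implementation would pass std::get<0>(ea) and std::get<0>(eb) on.
    return blas_ok ? Path::Blas : Path::Fallback;
}
```

Here `matrix_product_path(Transposed<Conjugated<Base>>{}, Base{})` yields `Path::Blas` (A ends up as ETrans::H, i.e., TRANSA='C'), while `matrix_product_path(Conjugated<Base>{}, Base{})` yields `Path::Fallback`.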

Regarding (1), we can mitigate this by exiting the "recursion" early for cases that the BLAS obviously can't handle. Also, taking the mdspan by value lets us move-construct the data handle, layout mapping, and accessor at each level, which reduces the cost in the (admittedly unusual) case where any of these is expensive to copy.
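The effect of taking views by value can be checked with a small C++14 sketch. Tracked is a hypothetical stand-in for an expensive-to-copy data handle, layout mapping, or accessor that records how many copies produced it, and Wrapped is a hypothetical stand-in for one level of nesting. Peeling three levels by value with std::move performs no copies, while peeling through const references copies once per level.

```cpp
#include <utility>

// Hypothetical stand-in for an expensive-to-copy piece of an mdspan.
// Each object records how many copies led to it; moves don't count.
struct Tracked {
    int copies = 0;
    constexpr Tracked() = default;
    constexpr Tracked(const Tracked& other) : copies(other.copies + 1) {}
    constexpr Tracked(Tracked&& other) noexcept : copies(other.copies) {}
};

// Hypothetical one-level wrapper, playing the role of a transposed,
// conjugated, or scaled view around an inner view.
template<class Inner> struct Wrapped { Inner inner; };

// By value: each level move-constructs the stripped view, so nothing is copied.
constexpr int copies_when_moving(Tracked t) { return t.copies; }
template<class Inner>
constexpr int copies_when_moving(Wrapped<Inner> w) {
    return copies_when_moving(Inner(std::move(w.inner)));
}

// By const reference: each level must copy to construct the stripped view.
constexpr int copies_when_copying(const Tracked& t) { return t.copies; }
template<class Inner>
constexpr int copies_when_copying(const Wrapped<Inner>& w) {
    return copies_when_copying(Inner(w.inner));
}
```

With three levels of nesting, `copies_when_moving` reports 0 copies and `copies_when_copying` reports 3, one per level.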
