1212// See the License for the specific language governing permissions and
1313// limitations under the License.
1414
15-
1615// Ignore CUTLASS warnings about type punning
1716#pragma GCC diagnostic push
1817#pragma GCC diagnostic ignored "-Wstrict-aliasing"
@@ -39,20 +38,35 @@ void moe_topk_select_kernel(const T* input,
3938 const int64_t k,
4039 cudaStream_t stream,
4140 const bool apply_norm_weight = false ,
42- const bool enable_softmax_top_k_fused = false
43- ) {
41+ const bool enable_softmax_top_k_fused = false ) {
4442 static constexpr int WARPS_PER_TB = 4 ;
4543
46- #define LAUNCH_TOPK_GATING_SOFTMAX_HELPER (N ) \
47- case N: { \
48- if (apply_norm_weight) { \
49- topk_gating_softmax_launcher_helper<T, N, WARPS_PER_TB, true >( \
50- input, bias, output, indices, source_row, num_rows, num_experts, k, stream); \
51- } else { \
52- topk_gating_softmax_launcher_helper<T, N, WARPS_PER_TB, false >( \
53- input, bias, output, indices, source_row, num_rows, num_experts, k, stream); \
54- } \
55- break ; \
44+ #define LAUNCH_TOPK_GATING_SOFTMAX_HELPER (N ) \
45+ case N: { \
46+ if (apply_norm_weight) { \
47+ topk_gating_softmax_launcher_helper<T, N, WARPS_PER_TB, true >( \
48+ input, \
49+ bias, \
50+ output, \
51+ indices, \
52+ source_row, \
53+ num_rows, \
54+ num_experts, \
55+ k, \
56+ stream); \
57+ } else { \
58+ topk_gating_softmax_launcher_helper<T, N, WARPS_PER_TB, false >( \
59+ input, \
60+ bias, \
61+ output, \
62+ indices, \
63+ source_row, \
64+ num_rows, \
65+ num_experts, \
66+ k, \
67+ stream); \
68+ } \
69+ break ; \
5670 }
5771 switch (num_experts) {
5872 LAUNCH_TOPK_GATING_SOFTMAX_HELPER (2 )
@@ -68,56 +82,56 @@ void moe_topk_select_kernel(const T* input,
6882 static constexpr int TPB = 256 ;
6983 const auto config_topk = Get1DBlocksAnd2DGridsMoe (num_rows);
7084 if (!enable_softmax_top_k_fused) {
71- moe_softmax<T, TPB><<<config_topk.block_per_grid, TPB, 0 , stream>>> (
72- input, softmax, num_experts, num_rows);
73- if (apply_norm_weight) {
74- moe_top_k<T, TPB, true >
75- <<<config_topk.block_per_grid, TPB, k * sizeof (T), stream>>> (softmax,
76- bias,
77- output,
78- indices,
79- source_row,
80- num_experts,
81- k,
82- num_rows);
83- } else {
84- moe_top_k<T, TPB, false >
85- <<<config_topk.block_per_grid, TPB, 0 , stream>>> (softmax,
86- bias,
87- output,
88- indices,
89- source_row,
90- num_experts,
91- k,
92- num_rows);
93- }
94- cudaGetLastError ();
95- }
96- else {
97- assert (k<=TPB);
98- if (apply_norm_weight) {
99- moe_softmax_top_k_fused<T, TPB, true >
100- <<<config_topk.block_per_grid, TPB, k * sizeof (T), stream>>> (input,
101- bias,
102- output,
103- indices,
104- source_row,
105- num_experts,
106- k,
107- num_rows);
108- } else {
109- moe_softmax_top_k_fused<T, TPB, false >
110- <<<config_topk.block_per_grid, TPB, 0 , stream>>> (input,
111- bias,
112- output,
113- indices,
114- source_row,
115- num_experts,
116- k,
117- num_rows);
118- }
85+ moe_softmax<T, TPB><<<config_topk.block_per_grid, TPB, 0 , stream>>> (
86+ input, softmax, num_experts, num_rows);
87+ if (apply_norm_weight) {
88+ moe_top_k<T, TPB, true >
89+ <<<config_topk.block_per_grid, TPB, k * sizeof (T), stream>>> (
90+ softmax,
91+ bias,
92+ output,
93+ indices,
94+ source_row,
95+ num_experts,
96+ k,
97+ num_rows);
98+ } else {
99+ moe_top_k<T, TPB, false >
100+ <<<config_topk.block_per_grid, TPB, 0 , stream>>> (softmax,
101+ bias,
102+ output,
103+ indices,
104+ source_row,
105+ num_experts,
106+ k,
107+ num_rows);
108+ }
109+ cudaGetLastError ();
110+ } else {
111+ assert (k <= TPB);
112+ if (apply_norm_weight) {
113+ moe_softmax_top_k_fused<T, TPB, true >
114+ <<<config_topk.block_per_grid, TPB, k * sizeof (T), stream>>> (
115+ input,
116+ bias,
117+ output,
118+ indices,
119+ source_row,
120+ num_experts,
121+ k,
122+ num_rows);
123+ } else {
124+ moe_softmax_top_k_fused<T, TPB, false >
125+ <<<config_topk.block_per_grid, TPB, 0 , stream>>> (input,
126+ bias,
127+ output,
128+ indices,
129+ source_row,
130+ num_experts,
131+ k,
132+ num_rows);
133+ }
119134 }
120-
121135 }
122136 }
123137}
@@ -146,6 +160,13 @@ std::vector<paddle::Tensor> MoETopKSelectKernel(
146160 auto topk_weights =
147161 GetEmptyTensor ({num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
148162
163+ // NOTE(sunxin): Avoid "invalid configuration argument" error caused by empty
164+ // tensors.
165+ if (gating_dims[0 ] == 0 ) {
166+ cudaGetLastError ();
167+ return {topk_ids, topk_weights};
168+ }
169+
149170 const int num_moe_inputs = AlignTo16 (num_rows * moe_topk);
150171 const int bytes = num_moe_inputs * sizeof (int );
151172
@@ -213,8 +234,7 @@ std::vector<std::vector<int64_t>> MoETopKSelectKernelInferShape(
213234 }
214235 const int num_rows = token_rows;
215236
216- return {{num_rows, moe_topk},
217- {num_rows, moe_topk}};
237+ return {{num_rows, moe_topk}, {num_rows, moe_topk}};
218238}
219239
220240std::vector<paddle::DataType> MoETopKSelectKernelInferDtype (
@@ -223,16 +243,15 @@ std::vector<paddle::DataType> MoETopKSelectKernelInferDtype(
223243 const int moe_topk,
224244 const bool apply_norm_weight,
225245 const bool enable_softmax_top_k_fused) {
226- return {paddle::DataType::INT64,
227- paddle::DataType::FLOAT32};
246+ return {paddle::DataType::INT64, paddle::DataType::FLOAT32};
228247}
229248
230-
231249PD_BUILD_STATIC_OP (moe_topk_select)
232250 .Inputs({" gating_logits" , paddle::Optional (" bias" )})
233- .Outputs({" topk_ids" ,
234- " topk_weights" })
235- .Attrs({" moe_topk:int" , " apply_norm_weight:bool" , " enable_softmax_top_k_fused:bool" })
251+ .Outputs({" topk_ids" , " topk_weights" })
252+ .Attrs({" moe_topk:int" ,
253+ " apply_norm_weight:bool" ,
254+ " enable_softmax_top_k_fused:bool" })
236255 .SetKernelFn(PD_KERNEL(MoETopKSelectKernel))
237256 .SetInferShapeFn(PD_INFER_SHAPE(MoETopKSelectKernelInferShape))
238257 .SetInferDtypeFn(PD_INFER_DTYPE(MoETopKSelectKernelInferDtype));
0 commit comments