Skip to content

Commit d5a9b75

Browse files
authored
fix cutlass ep (#5337)
1 parent 690bcb8 commit d5a9b75

File tree

1 file changed

+89
-70
lines changed

1 file changed

+89
-70
lines changed

custom_ops/gpu_ops/moe/moe_topk_select.cu

Lines changed: 89 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15-
1615
// Ignore CUTLASS warnings about type punning
1716
#pragma GCC diagnostic push
1817
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
@@ -39,20 +38,35 @@ void moe_topk_select_kernel(const T* input,
3938
const int64_t k,
4039
cudaStream_t stream,
4140
const bool apply_norm_weight = false,
42-
const bool enable_softmax_top_k_fused = false
43-
) {
41+
const bool enable_softmax_top_k_fused = false) {
4442
static constexpr int WARPS_PER_TB = 4;
4543

46-
#define LAUNCH_TOPK_GATING_SOFTMAX_HELPER(N) \
47-
case N: { \
48-
if (apply_norm_weight) { \
49-
topk_gating_softmax_launcher_helper<T, N, WARPS_PER_TB, true>( \
50-
input, bias, output, indices, source_row, num_rows, num_experts, k, stream); \
51-
} else { \
52-
topk_gating_softmax_launcher_helper<T, N, WARPS_PER_TB, false>( \
53-
input, bias, output, indices, source_row, num_rows, num_experts, k, stream); \
54-
} \
55-
break; \
44+
#define LAUNCH_TOPK_GATING_SOFTMAX_HELPER(N) \
45+
case N: { \
46+
if (apply_norm_weight) { \
47+
topk_gating_softmax_launcher_helper<T, N, WARPS_PER_TB, true>( \
48+
input, \
49+
bias, \
50+
output, \
51+
indices, \
52+
source_row, \
53+
num_rows, \
54+
num_experts, \
55+
k, \
56+
stream); \
57+
} else { \
58+
topk_gating_softmax_launcher_helper<T, N, WARPS_PER_TB, false>( \
59+
input, \
60+
bias, \
61+
output, \
62+
indices, \
63+
source_row, \
64+
num_rows, \
65+
num_experts, \
66+
k, \
67+
stream); \
68+
} \
69+
break; \
5670
}
5771
switch (num_experts) {
5872
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(2)
@@ -68,56 +82,56 @@ void moe_topk_select_kernel(const T* input,
6882
static constexpr int TPB = 256;
6983
const auto config_topk = Get1DBlocksAnd2DGridsMoe(num_rows);
7084
if (!enable_softmax_top_k_fused) {
71-
moe_softmax<T, TPB><<<config_topk.block_per_grid, TPB, 0, stream>>>(
72-
input, softmax, num_experts, num_rows);
73-
if (apply_norm_weight) {
74-
moe_top_k<T, TPB, true>
75-
<<<config_topk.block_per_grid, TPB, k * sizeof(T), stream>>>(softmax,
76-
bias,
77-
output,
78-
indices,
79-
source_row,
80-
num_experts,
81-
k,
82-
num_rows);
83-
} else {
84-
moe_top_k<T, TPB, false>
85-
<<<config_topk.block_per_grid, TPB, 0, stream>>>(softmax,
86-
bias,
87-
output,
88-
indices,
89-
source_row,
90-
num_experts,
91-
k,
92-
num_rows);
93-
}
94-
cudaGetLastError();
95-
}
96-
else {
97-
assert(k<=TPB);
98-
if (apply_norm_weight) {
99-
moe_softmax_top_k_fused<T, TPB, true>
100-
<<<config_topk.block_per_grid, TPB, k * sizeof(T), stream>>>(input,
101-
bias,
102-
output,
103-
indices,
104-
source_row,
105-
num_experts,
106-
k,
107-
num_rows);
108-
} else {
109-
moe_softmax_top_k_fused<T, TPB, false>
110-
<<<config_topk.block_per_grid, TPB, 0, stream>>>(input,
111-
bias,
112-
output,
113-
indices,
114-
source_row,
115-
num_experts,
116-
k,
117-
num_rows);
118-
}
85+
moe_softmax<T, TPB><<<config_topk.block_per_grid, TPB, 0, stream>>>(
86+
input, softmax, num_experts, num_rows);
87+
if (apply_norm_weight) {
88+
moe_top_k<T, TPB, true>
89+
<<<config_topk.block_per_grid, TPB, k * sizeof(T), stream>>>(
90+
softmax,
91+
bias,
92+
output,
93+
indices,
94+
source_row,
95+
num_experts,
96+
k,
97+
num_rows);
98+
} else {
99+
moe_top_k<T, TPB, false>
100+
<<<config_topk.block_per_grid, TPB, 0, stream>>>(softmax,
101+
bias,
102+
output,
103+
indices,
104+
source_row,
105+
num_experts,
106+
k,
107+
num_rows);
108+
}
109+
cudaGetLastError();
110+
} else {
111+
assert(k <= TPB);
112+
if (apply_norm_weight) {
113+
moe_softmax_top_k_fused<T, TPB, true>
114+
<<<config_topk.block_per_grid, TPB, k * sizeof(T), stream>>>(
115+
input,
116+
bias,
117+
output,
118+
indices,
119+
source_row,
120+
num_experts,
121+
k,
122+
num_rows);
123+
} else {
124+
moe_softmax_top_k_fused<T, TPB, false>
125+
<<<config_topk.block_per_grid, TPB, 0, stream>>>(input,
126+
bias,
127+
output,
128+
indices,
129+
source_row,
130+
num_experts,
131+
k,
132+
num_rows);
133+
}
119134
}
120-
121135
}
122136
}
123137
}
@@ -146,6 +160,13 @@ std::vector<paddle::Tensor> MoETopKSelectKernel(
146160
auto topk_weights =
147161
GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
148162

163+
// NOTE(sunxin): Avoid "invalid configuration argument" error caused by empty
164+
// tensors.
165+
if (gating_dims[0] == 0) {
166+
cudaGetLastError();
167+
return {topk_ids, topk_weights};
168+
}
169+
149170
const int num_moe_inputs = AlignTo16(num_rows * moe_topk);
150171
const int bytes = num_moe_inputs * sizeof(int);
151172

@@ -213,8 +234,7 @@ std::vector<std::vector<int64_t>> MoETopKSelectKernelInferShape(
213234
}
214235
const int num_rows = token_rows;
215236

216-
return {{num_rows, moe_topk},
217-
{num_rows, moe_topk}};
237+
return {{num_rows, moe_topk}, {num_rows, moe_topk}};
218238
}
219239

220240
std::vector<paddle::DataType> MoETopKSelectKernelInferDtype(
@@ -223,16 +243,15 @@ std::vector<paddle::DataType> MoETopKSelectKernelInferDtype(
223243
const int moe_topk,
224244
const bool apply_norm_weight,
225245
const bool enable_softmax_top_k_fused) {
226-
return {paddle::DataType::INT64,
227-
paddle::DataType::FLOAT32};
246+
return {paddle::DataType::INT64, paddle::DataType::FLOAT32};
228247
}
229248

230-
231249
PD_BUILD_STATIC_OP(moe_topk_select)
232250
.Inputs({"gating_logits", paddle::Optional("bias")})
233-
.Outputs({"topk_ids",
234-
"topk_weights"})
235-
.Attrs({"moe_topk:int", "apply_norm_weight:bool", "enable_softmax_top_k_fused:bool"})
251+
.Outputs({"topk_ids", "topk_weights"})
252+
.Attrs({"moe_topk:int",
253+
"apply_norm_weight:bool",
254+
"enable_softmax_top_k_fused:bool"})
236255
.SetKernelFn(PD_KERNEL(MoETopKSelectKernel))
237256
.SetInferShapeFn(PD_INFER_SHAPE(MoETopKSelectKernelInferShape))
238257
.SetInferDtypeFn(PD_INFER_DTYPE(MoETopKSelectKernelInferDtype));

0 commit comments

Comments
 (0)