Skip to content

Commit 5e19269

Browse files
zhangchenyi_dlneilzhuu
authored andcommitted
[Metax] fix build error of rejection
1 parent 05c01bd commit 5e19269

File tree

1 file changed

+11
-10
lines changed

1 file changed

+11
-10
lines changed

custom_ops/gpu_ops/sample_kernels/sampling.cuh

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -434,14 +434,14 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs,
434434
__syncthreads();
435435
aggregate_gt_pivot_0 = temp_storage.block_aggregate.pair;
436436

437-
#ifdef PADDLE_WITH_COREX
437+
#if defined(PADDLE_WITH_COREX) || defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU)
438438
aggregate_gt_pivot_1 += BlockReduce<ValueCount<float>, BLOCK_THREADS>(
439439
temp_storage.block_prim.reduce_value_count)
440440
.Sum(probs_gt_pivot_1);
441441
#else
442442
aggregate_gt_pivot_1 += BlockReduce<ValueCount<float>, BLOCK_THREADS>(
443443
temp_storage.block_prim.reduce_value_count)
444-
.Sum(probs_gt_pivot_1);
444+
.Sum<VEC_SIZE>(probs_gt_pivot_1);
445445
#endif
446446
if (tx == 0) {
447447
temp_storage.block_aggregate.pair = aggregate_gt_pivot_1;
@@ -573,14 +573,14 @@ __global__ void TopPSamplingFromProbKernel(DType* probs,
573573
__syncthreads();
574574
aggregate_gt_pivot_0 = temp_storage.block_aggregate.value;
575575

576-
#ifdef PADDLE_WITH_COREX
576+
#if defined(PADDLE_WITH_COREX) || defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU)
577577
aggregate_gt_pivot_1 +=
578578
BlockReduce<float, BLOCK_THREADS>(temp_storage.block_prim.reduce)
579579
.Sum(probs_gt_pivot_1);
580580
#else
581581
aggregate_gt_pivot_1 +=
582582
BlockReduce<float, BLOCK_THREADS>(temp_storage.block_prim.reduce)
583-
.Sum(probs_gt_pivot_1);
583+
.Sum<VEC_SIZE>(probs_gt_pivot_1);
584584
#endif
585585
if (tx == 0) {
586586
temp_storage.block_aggregate.value = aggregate_gt_pivot_1;
@@ -822,16 +822,17 @@ __global__ void TopKRenormProbKernel(DType* probs,
822822
#endif
823823
__syncthreads();
824824

825-
#ifdef PADDLE_WITH_COREX
826-
aggregate_gt_pivot_1 +=
827-
BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
828-
temp_storage.block_prim.reduce_value_count)
829-
.Sum(probs_gt_pivot_1_pair);
825+
if defined (PADDLE_WITH_COREX)
826+
|| defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU)
827+
aggregate_gt_pivot_1 +=
828+
BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
829+
temp_storage.block_prim.reduce_value_count)
830+
.Sum(probs_gt_pivot_1_pair);
830831
#else
831832
aggregate_gt_pivot_1 +=
832833
BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
833834
temp_storage.block_prim.reduce_value_count)
834-
.Sum(probs_gt_pivot_1_pair);
835+
.Sum<VEC_SIZE>(probs_gt_pivot_1_pair);
835836
#endif
836837
__syncthreads();
837838
}

0 commit comments

Comments
 (0)