@@ -434,14 +434,14 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs,
434434 __syncthreads ();
435435 aggregate_gt_pivot_0 = temp_storage.block_aggregate .pair ;
436436
437- #ifdef PADDLE_WITH_COREX
437+ #if defined( PADDLE_WITH_COREX) || defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU)
438438 aggregate_gt_pivot_1 += BlockReduce<ValueCount<float >, BLOCK_THREADS>(
439439 temp_storage.block_prim .reduce_value_count )
440440 .Sum (probs_gt_pivot_1);
441441#else
442442 aggregate_gt_pivot_1 += BlockReduce<ValueCount<float >, BLOCK_THREADS>(
443443 temp_storage.block_prim .reduce_value_count )
444- .Sum (probs_gt_pivot_1);
444+ .Sum <VEC_SIZE> (probs_gt_pivot_1);
445445#endif
446446 if (tx == 0 ) {
447447 temp_storage.block_aggregate .pair = aggregate_gt_pivot_1;
@@ -573,14 +573,14 @@ __global__ void TopPSamplingFromProbKernel(DType* probs,
573573 __syncthreads ();
574574 aggregate_gt_pivot_0 = temp_storage.block_aggregate .value ;
575575
576- #ifdef PADDLE_WITH_COREX
576+ #if defined( PADDLE_WITH_COREX) || defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU)
577577 aggregate_gt_pivot_1 +=
578578 BlockReduce<float , BLOCK_THREADS>(temp_storage.block_prim .reduce )
579579 .Sum (probs_gt_pivot_1);
580580#else
581581 aggregate_gt_pivot_1 +=
582582 BlockReduce<float , BLOCK_THREADS>(temp_storage.block_prim .reduce )
583- .Sum (probs_gt_pivot_1);
583+ .Sum <VEC_SIZE> (probs_gt_pivot_1);
584584#endif
585585 if (tx == 0 ) {
586586 temp_storage.block_aggregate .value = aggregate_gt_pivot_1;
@@ -822,7 +822,7 @@ __global__ void TopKRenormProbKernel(DType* probs,
822822#endif
823823 __syncthreads ();
824824
825- #ifdef PADDLE_WITH_COREX
825+ #if defined( PADDLE_WITH_COREX) || defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU)
826826 aggregate_gt_pivot_1 +=
827827 BlockReduce<ValueCount<float >, BLOCK_THREADS, REDUCE_ALGORITHM>(
828828 temp_storage.block_prim .reduce_value_count )
@@ -831,7 +831,7 @@ __global__ void TopKRenormProbKernel(DType* probs,
831831 aggregate_gt_pivot_1 +=
832832 BlockReduce<ValueCount<float >, BLOCK_THREADS, REDUCE_ALGORITHM>(
833833 temp_storage.block_prim .reduce_value_count )
834- .Sum (probs_gt_pivot_1_pair);
834+ .Sum <VEC_SIZE> (probs_gt_pivot_1_pair);
835835#endif
836836 __syncthreads ();
837837 }
0 commit comments