diff --git a/runtime/include/executor/ocl/kernel_utils.cl b/runtime/include/executor/ocl/kernel_utils.cl index 097748c..7387809 100644 --- a/runtime/include/executor/ocl/kernel_utils.cl +++ b/runtime/include/executor/ocl/kernel_utils.cl @@ -340,6 +340,20 @@ void _change_hab_habpim(__global uint8_t* __restrict__ pim_ctr, uint64_t offset R_CMD(&pim_ctr[addr + offset]); } +/* GEMV mode-switch helper: builds addr = addr_gen_(get_group_id(0), 0, 0, 0, 0x3fff, 0x0), then issues W_CMD_R_C on pim_ctr[addr + offset] with source gemv_hab_to_hab_pim + offset, R_CMD on the same location, and B_CMD(1). pim_ctr: control region base; offset: byte offset added to both the target and the source. Under EMULATOR an extra emulator_trace argument is accepted; it is not referenced by name in this body (presumably consumed by the command macros - TODO confirm). NOTE(review): the inline sequence this replaces in pim_gemm.cl passed the kernel variable ch as the channel; here it is get_group_id(0) - confirm the two are always equal. */ void _change_gemv_hab_habpim(__global uint8_t* __restrict__ pim_ctr, uint64_t offset +#ifdef EMULATOR + , + __global PimMemTracer* emulator_trace +#endif + ) +{ + uint64_t addr; + addr = addr_gen_(get_group_id(0), 0, 0, 0, 0x3fff, 0x0); + W_CMD_R_C(&pim_ctr[addr + offset], gemv_hab_to_hab_pim + offset); + R_CMD(&pim_ctr[addr + offset]); + B_CMD(1); +} + void _change_habpim_hab(__global uint8_t* __restrict__ pim_ctr, uint64_t offset #ifdef EMULATOR , @@ -380,6 +394,19 @@ void _park_out(__global uint8_t* __restrict__ pim_ctr, int gidx, int num_ba, uin W_CMD(&pim_ctr[addr + offset]); } +/* Address generation with a linear offset folded in. The offset is split into a low part (offset % offset_size) and a column advance (offset / offset_size); the advanced column wraps at col_size and the overflow is added to row. Returns addr_gen_(chan, rank, bankgroup, bank, row_s, col_s) plus the low part. No bounds check is applied to the resulting row. NOTE(review): every intermediate is uint32_t, while the callers in pim_gemm.cl pass a uint64_t (offset_unsafe), so values above 32 bits are truncated at the call - confirm the reachable range. */ uint64_t addr_gen_s(uint32_t chan, uint32_t rank, uint32_t bankgroup, uint32_t bank, uint32_t row, uint32_t col, uint32_t offset) +{ + uint32_t offset_size = 1 << vega20_pbi.num_offset_bit; /* span covered by the low offset bits */ + uint32_t col_size = vega20_pbi.num_col / vega20_pbi.bl; /* column count used for wrapping: num_col divided by bl */ + + uint32_t offset_s = offset % offset_size; /* low part, added to the final address */ + uint32_t new_col = col + offset / offset_size; /* column before wrapping */ + uint32_t col_s = new_col % col_size; /* wrapped column */ + uint32_t row_s = row + new_col / col_size; /* column overflow carried into the row */ + + return addr_gen_(chan, rank, bankgroup, bank, row_s, col_s) + offset_s; +} + #ifdef EMULATOR #define park_in(a, b, c, d) _park_in(a, b, c, d, emulator_trace) #define change_sb_hab(a, b) _change_sb_hab(a, b, emulator_trace) @@ -387,6 +414,7 @@ void _park_out(__global uint8_t* __restrict__ pim_ctr, int gidx, int num_ba, uin #define program_crf_mod(a, b, c, d) _program_crf_mod(a, b, c, d, emulator_trace) #define program_srf(a, b, c) _program_srf(a, b, c, emulator_trace) #define change_hab_habpim(a, b) _change_hab_habpim(a, b, emulator_trace) +#define change_gemv_hab_habpim(a, b) 
_change_gemv_hab_habpim(a, b, emulator_trace) #define change_habpim_hab(a, b) _change_habpim_hab(a, b, emulator_trace) #define change_hab_sb(a, b, c) _change_hab_sb(a, b, c, emulator_trace) #define park_out(a, b, c, d) _park_out(a, b, c, d, emulator_trace) @@ -398,6 +426,7 @@ void _park_out(__global uint8_t* __restrict__ pim_ctr, int gidx, int num_ba, uin #define program_crf_mod(a, b, c, d) _program_crf_mod(a, b, c, d) #define program_srf(a, b, c) _program_srf(a, b, c) #define change_hab_habpim(a, b) _change_hab_habpim(a, b) +#define change_gemv_hab_habpim(a, b) _change_gemv_hab_habpim(a, b) #define change_habpim_hab(a, b) _change_habpim_hab(a, b) #define change_hab_sb(a, b, c) _change_hab_sb(a, b, c) #define park_out(a, b, c, d) _park_out(a, b, c, d) diff --git a/runtime/include/executor/ocl/pim_gemm.cl b/runtime/include/executor/ocl/pim_gemm.cl index c8c6a2a..7944613 100644 --- a/runtime/include/executor/ocl/pim_gemm.cl +++ b/runtime/include/executor/ocl/pim_gemm.cl @@ -504,7 +504,7 @@ __kernel void pim_aligned_gemm_bias_relu_fp16( int gidx = get_local_id(0) >> 1; uint64_t offset = w_idx << 4; uint64_t addr; - int gemv_cnt = 0; + uint64_t offset_unsafe; #endif #if PARK_IN @@ -528,107 +528,47 @@ __kernel void pim_aligned_gemm_bias_relu_fp16( barrier(CLK_GLOBAL_MEM_FENCE); #if COMPUTE_GEMM if (get_local_id(0) < 16) { - for (int i = 0; i < iter_cnt; i++) { + for (int b = 0; b < iter_cnt; b++) { /* change HAB mode to HAB_PIM mode */ for (int in_idx = 0; in_idx < inout_h; in_idx++) { for (int o_idx = 0; o_idx < n_out_tile; o_idx++) { - addr = addr_gen_(ch, 0, 0, 0, 0x3fff, 0x0); - W_CMD_R_C(&pim_ctr[addr + offset], gemv_hab_to_hab_pim + offset); - R_CMD(&pim_ctr[addr + offset]); - B_CMD(1); - - uint64_t i_offset = gemv_cnt * (n_in_tile << grf_shift); - int r_offset = (o_idx * n_in_tile) >> 1; - - for (int i_idx = 0; i_idx < n_in_tile; i_idx += 2) { - /* write grf_A from WRIO */ - uint64_t i_addr = (i_offset + ((i_idx << grf_shift) + gidx)) << trans_shift; - addr = 
addr_gen_(ch, 0, 0, 0, 0x3fff, 0x8 + gidx); - W_CMD_R(&pim_ctr[addr + offset], &input[i_addr + offset]); - R_CMD(&pim_ctr[addr + offset]); - B_CMD(1); - - even_row = ((i_idx >> 1) + r_offset) << 1; - odd_row = even_row + 1; - - addr = addr_gen_(ch, 0, 0, 0, even_row, gidx); - R_CMD(&weight[addr + offset]); - - addr = addr_gen_(ch, 0, 0, 0, even_row, gidx + 8); - R_CMD(&weight[addr + offset]); - - addr = addr_gen_(ch, 0, 0, 0, even_row, gidx + 16); - R_CMD(&weight[addr + offset]); - - addr = addr_gen_(ch, 0, 0, 0, even_row, gidx + 24); - R_CMD(&weight[addr + offset]); - - addr = addr_gen_(ch, 0, 0, 0, odd_row, gidx); - R_CMD(&weight[addr + offset]); - - addr = addr_gen_(ch, 0, 0, 0, odd_row, gidx + 8); - R_CMD(&weight[addr + offset]); - - addr = addr_gen_(ch, 0, 0, 0, odd_row, gidx + 16); - R_CMD(&weight[addr + offset]); - - addr = addr_gen_(ch, 0, 0, 0, odd_row, gidx + 24); - R_CMD(&weight[addr + offset]); - B_CMD(1); + change_gemv_hab_habpim(pim_ctr, offset); +#pragma unroll + for (int bk = 0; bk < 2; bk++) { +#pragma unroll + for (int i = 0; i < n_in_tile / 2; i++) { + // write grf_A from WRIO + uint64_t i_addr = + (((b * inout_h * n_in_tile + in_idx * n_in_tile + i*2 + bk) + << grf_shift) << trans_shift) + get_local_id(0) * 16; + addr = addr_gen_s(ch, 0, 0, bk, 0x3fff, 0x8, get_local_id(0) * 16); + W_CMD_R(&pim_ctr[addr], &input[i_addr]); + R_CMD(&pim_ctr[addr]); + B_CMD(1); + + offset_unsafe = (((b * n_out_tile * n_in_tile + o_idx * n_in_tile + i*2) << grf_shift) << trans_shift) * n_in_tile / 2 + get_local_id(0) * 16; + for (int j = 0; j < 8; j++) { + addr = addr_gen_s(ch, 0, 0, bk, j / 4, (j % 4) * 8, offset_unsafe); + R_CMD(&weight[addr]); + } + B_CMD(1); + } } - for (int i_idx = 1; i_idx < n_in_tile; i_idx += 2) { - uint64_t i_addr = (i_offset + ((i_idx << grf_shift) + gidx)) << trans_shift; - addr = addr_gen_(ch, 0, 0, 1, 0x3fff, 0x8 + gidx); - W_CMD_R(&pim_ctr[addr + offset], &input[i_addr + offset]); - R_CMD(&pim_ctr[addr + offset]); - B_CMD(1); - - 
even_row = ((i_idx >> 1) + r_offset) << 1; - odd_row = even_row + 1; - - addr = addr_gen_(ch, 0, 0, 1, even_row, gidx); - R_CMD(&weight[addr + offset]); - - addr = addr_gen_(ch, 0, 0, 1, even_row, gidx + 8); - R_CMD(&weight[addr + offset]); - - addr = addr_gen_(ch, 0, 0, 1, even_row, gidx + 16); - R_CMD(&weight[addr + offset]); - - addr = addr_gen_(ch, 0, 0, 1, even_row, gidx + 24); - R_CMD(&weight[addr + offset]); - - addr = addr_gen_(ch, 0, 0, 1, odd_row, gidx); - R_CMD(&weight[addr + offset]); - - addr = addr_gen_(ch, 0, 0, 1, odd_row, gidx + 8); - R_CMD(&weight[addr + offset]); - - addr = addr_gen_(ch, 0, 0, 1, odd_row, gidx + 16); - R_CMD(&weight[addr + offset]); - - addr = addr_gen_(ch, 0, 0, 1, odd_row, gidx + 24); - R_CMD(&weight[addr + offset]); - B_CMD(1); - } - loc = (gemv_cnt * n_out_tile << grf_shift) + (o_idx << grf_shift) + gidx; - row = loc >> col_shift; - col = loc % num_col; + offset_unsafe = (((b * inout_h * n_out_tile + in_idx * n_out_tile + o_idx) << grf_shift) << trans_shift) + get_local_id(0) * 16 + + gidx; // pipeline delay // FIX : If alu is in operation, NOP should be added. 
- addr = addr_gen_(ch, 0, 0, 1, row, col); - W_CMD(&pim_partial_sum[addr + offset]); - W_CMD(&pim_partial_sum[addr + offset]); - R_CMD(&pim_partial_sum[addr + offset]); + addr = addr_gen_s(ch, 0, 0, 1, 0, 0, offset_unsafe); + W_CMD(&pim_partial_sum[addr]); + W_CMD(&pim_partial_sum[addr]); + R_CMD(&pim_partial_sum[addr]); B_CMD(1); change_habpim_hab(pim_ctr, offset); } - gemv_cnt++; } - weight += (in_w * out_w << 1); } } #endif @@ -666,8 +606,6 @@ __kernel void pim_aligned_gemm_bias_relu_fp16( half t_output; #endif - gemv_cnt = 0; - for (int i = 0; i < iter_cnt; i++) { for (int in_idx = 0; in_idx < inout_h; in_idx++) { for (int oi = 0; oi < n_out_tile; oi++) { @@ -675,18 +613,17 @@ __kernel void pim_aligned_gemm_bias_relu_fp16( /* out_idx = oi * out_per_tile + t_idx; */ out_idx = (oi << 12) + t_idx; if (out_idx < out_w) { - li = gemv_cnt * n_out_tile + oi; + li = (i * inout_h + in_idx) * n_out_tile + oi; row = li >> 2; col = get_local_id(0) % 8 + ((li % 4) << 3); addr = addr_gen_(ch, 0, bg, ba, row, col); t_output = 0; #if NVIDIA_GPU for (int ti = 0; ti < 16; ti++) { - // t_output += ((half*)pim_partial_sum)[(addr >> 1) + ti]; t_output += vload_half(addr + (ti<<1) , pim_partial_sum); } - out_offset = gemv_cnt * out_w + out_idx; - if (is_bias) t_output += vload_half(out_offset << 1 , bias) + out_offset = (i * inout_h + in_idx) * out_w + out_idx; + if (is_bias) t_output += vload_half(out_offset << 1 , bias); if (is_relu) if (t_output < (float)0.) t_output = (float)0.; vstore_half(t_output , out_offset << 1 , output); @@ -694,7 +631,7 @@ __kernel void pim_aligned_gemm_bias_relu_fp16( for (int ti = 0; ti < 16; ti++) { t_output += ((half*)pim_partial_sum)[(addr >> 1) + ti]; } - out_offset = gemv_cnt * out_w + out_idx; + out_offset = (i * inout_h + in_idx) * out_w + out_idx; if (is_bias) t_output += ((half*)bias)[out_offset]; if (is_relu) if (t_output < (half)0.) 
t_output = (half)0.; @@ -702,7 +639,6 @@ __kernel void pim_aligned_gemm_bias_relu_fp16( #endif } } - gemv_cnt++; } } #endif @@ -740,7 +676,7 @@ __kernel void pim_aligned_gemm_bias_relu_8tile_fp16( int gidx = get_local_id(0) >> 1; uint64_t offset = w_idx << 4; uint64_t addr; - int gemv_cnt = 0; + uint64_t offset_unsafe; #endif #if PARK_IN @@ -764,109 +700,47 @@ __kernel void pim_aligned_gemm_bias_relu_8tile_fp16( barrier(CLK_GLOBAL_MEM_FENCE); #if COMPUTE_GEMM if (get_local_id(0) < 16) { - for (int i = 0; i < iter_cnt; i++) { + for (int b = 0; b < iter_cnt; b++) { /* change HAB mode to HAB_PIM mode */ for (int in_idx = 0; in_idx < inout_h; in_idx++) { for (int o_idx = 0; o_idx < n_out_tile; o_idx++) { - addr = addr_gen_(ch, 0, 0, 0, 0x3fff, 0x0); - W_CMD_R_C(&pim_ctr[addr + offset], gemv_hab_to_hab_pim + offset); - R_CMD(&pim_ctr[addr + offset]); - B_CMD(1); - - uint64_t i_offset = gemv_cnt * (n_in_tile << grf_shift); - int r_offset = (o_idx * n_in_tile) >> 1; - + change_gemv_hab_habpim(pim_ctr, offset); #pragma unroll - for (int i = 0, i_idx = 0; i < 4; i++, i_idx += 2) { - /* write grf_A from WRIO */ - uint64_t i_addr = (i_offset + ((i_idx << grf_shift) + gidx)) << trans_shift; - addr = addr_gen_(ch, 0, 0, 0, 0x3fff, 0x8 + gidx); - W_CMD_R(&pim_ctr[addr + offset], &input[i_addr + offset]); - R_CMD(&pim_ctr[addr + offset]); - B_CMD(1); - - even_row = ((i_idx >> 1) + r_offset) << 1; - odd_row = even_row + 1; - - addr = addr_gen_(ch, 0, 0, 0, even_row, gidx); - R_CMD(&weight[addr + offset]); - - addr = addr_gen_(ch, 0, 0, 0, even_row, gidx + 8); - R_CMD(&weight[addr + offset]); - - addr = addr_gen_(ch, 0, 0, 0, even_row, gidx + 16); - R_CMD(&weight[addr + offset]); - - addr = addr_gen_(ch, 0, 0, 0, even_row, gidx + 24); - R_CMD(&weight[addr + offset]); - - addr = addr_gen_(ch, 0, 0, 0, odd_row, gidx); - R_CMD(&weight[addr + offset]); - - addr = addr_gen_(ch, 0, 0, 0, odd_row, gidx + 8); - R_CMD(&weight[addr + offset]); - - addr = addr_gen_(ch, 0, 0, 0, odd_row, gidx 
+ 16); - R_CMD(&weight[addr + offset]); - - addr = addr_gen_(ch, 0, 0, 0, odd_row, gidx + 24); - R_CMD(&weight[addr + offset]); - B_CMD(1); - } - + for (int bk = 0; bk < 2; bk++) { #pragma unroll - for (int i = 0, i_idx = 1; i < 4; i++, i_idx += 2) { - uint64_t i_addr = (i_offset + ((i_idx << grf_shift) + gidx)) << trans_shift; - addr = addr_gen_(ch, 0, 0, 1, 0x3fff, 0x8 + gidx); - W_CMD_R(&pim_ctr[addr + offset], &input[i_addr + offset]); - R_CMD(&pim_ctr[addr + offset]); - B_CMD(1); - - even_row = ((i_idx >> 1) + r_offset) << 1; - odd_row = even_row + 1; - - addr = addr_gen_(ch, 0, 0, 1, even_row, gidx); - R_CMD(&weight[addr + offset]); - - addr = addr_gen_(ch, 0, 0, 1, even_row, gidx + 8); - R_CMD(&weight[addr + offset]); - - addr = addr_gen_(ch, 0, 0, 1, even_row, gidx + 16); - R_CMD(&weight[addr + offset]); - - addr = addr_gen_(ch, 0, 0, 1, even_row, gidx + 24); - R_CMD(&weight[addr + offset]); - - addr = addr_gen_(ch, 0, 0, 1, odd_row, gidx); - R_CMD(&weight[addr + offset]); - - addr = addr_gen_(ch, 0, 0, 1, odd_row, gidx + 8); - R_CMD(&weight[addr + offset]); - - addr = addr_gen_(ch, 0, 0, 1, odd_row, gidx + 16); - R_CMD(&weight[addr + offset]); - - addr = addr_gen_(ch, 0, 0, 1, odd_row, gidx + 24); - R_CMD(&weight[addr + offset]); - B_CMD(1); + for (int i = 0; i < n_in_tile / 2; i++) { + // write grf_A from WRIO + uint64_t i_addr = + (((b * inout_h * n_in_tile + in_idx * n_in_tile + i*2 + bk) + << grf_shift) << trans_shift) + get_local_id(0) * 16; + addr = addr_gen_s(ch, 0, 0, bk, 0x3fff, 0x8, get_local_id(0) * 16); + W_CMD_R(&pim_ctr[addr], &input[i_addr]); + R_CMD(&pim_ctr[addr]); + B_CMD(1); + + offset_unsafe = (((b * n_out_tile * n_in_tile + o_idx * n_in_tile + i*2) << grf_shift) << trans_shift) * n_in_tile / 2 + get_local_id(0) * 16; + for (int j = 0; j < 8; j++) { + addr = addr_gen_s(ch, 0, 0, bk, j / 4, (j % 4) * 8, offset_unsafe); + R_CMD(&weight[addr]); + } + B_CMD(1); + } } - loc = (gemv_cnt * n_out_tile << grf_shift) + (o_idx << grf_shift) + 
gidx; - row = loc >> col_shift; - col = loc % num_col; + + offset_unsafe = (((b * inout_h * n_out_tile + in_idx * n_out_tile + o_idx) << grf_shift) << trans_shift) + get_local_id(0) * 16 + + gidx; // pipeline delay // FIX : If alu is in operation, NOP should be added. - addr = addr_gen_(ch, 0, 0, 1, row, col); - W_CMD(&pim_partial_sum[addr + offset]); - W_CMD(&pim_partial_sum[addr + offset]); - R_CMD(&pim_partial_sum[addr + offset]); + addr = addr_gen_s(ch, 0, 0, 1, 0, 0, offset_unsafe); + W_CMD(&pim_partial_sum[addr]); + W_CMD(&pim_partial_sum[addr]); + R_CMD(&pim_partial_sum[addr]); B_CMD(1); change_habpim_hab(pim_ctr, offset); } - gemv_cnt++; } - weight += (in_w * out_w << 1); } } #endif @@ -882,7 +756,7 @@ __kernel void pim_aligned_gemm_bias_relu_8tile_fp16( park_out(pim_ctr, gidx, num_ba, offset); } #endif - + #ifdef EMULATOR if (get_group_id(0) == 0 && get_local_id(0) == 0) { frd_size[0] = emulator_trace->g_ridx[0]; @@ -901,7 +775,6 @@ __kernel void pim_aligned_gemm_bias_relu_8tile_fp16( #else half t_output; #endif - gemv_cnt = 0; for (int i = 0; i < iter_cnt; i++) { for (int in_idx = 0; in_idx < inout_h; in_idx++) { @@ -910,7 +783,7 @@ __kernel void pim_aligned_gemm_bias_relu_8tile_fp16( /* out_idx = oi * out_per_tile + t_idx; */ out_idx = (oi << 12) + t_idx; if (out_idx < out_w) { - li = gemv_cnt * n_out_tile + oi; + li = (i * inout_h + in_idx) * n_out_tile + oi; row = li >> 2; col = get_local_id(0) % 8 + ((li % 4) << 3); addr = addr_gen_(ch, 0, bg, ba, row, col); @@ -919,16 +792,16 @@ __kernel void pim_aligned_gemm_bias_relu_8tile_fp16( for (int ti = 0; ti < 16; ti++) { t_output += vload_half(addr + (ti<<1) , pim_partial_sum); } - out_offset = gemv_cnt * out_w + out_idx; + out_offset = (i * inout_h + in_idx) * out_w + out_idx; if (is_bias) t_output += vload_half(out_offset << 1 , bias); if (is_relu) - if (t_output < (half)0.) t_output = (half)0.; + if (t_output < (float)0.) 
t_output = (float)0.; vstore_half(t_output , out_offset << 1 , output); #else for (int ti = 0; ti < 16; ti++) { t_output += ((half*)pim_partial_sum)[(addr >> 1) + ti]; } - out_offset = gemv_cnt * out_w + out_idx; + out_offset = (i * inout_h + in_idx) * out_w + out_idx; if (is_bias) t_output += ((half*)bias)[out_offset]; if (is_relu) if (t_output < (half)0.) t_output = (half)0.; @@ -936,7 +809,6 @@ __kernel void pim_aligned_gemm_bias_relu_8tile_fp16( #endif } } - gemv_cnt++; } } #endif