From 3a14754158398e06c8cebcd439dca41e229fb9ed Mon Sep 17 00:00:00 2001 From: niushengxiao Date: Thu, 17 Jul 2025 15:42:11 +0800 Subject: [PATCH 1/9] feat: move init_req_to_token_indexes and copy_kv_index_to_req to alloc fun --- lightllm/common/basemodel/basemodel.py | 66 +------------------ lightllm/common/basemodel/cuda_graph.py | 4 +- lightllm/common/infer_utils.py | 4 +- lightllm/common/mem_manager.py | 28 +++++++- lightllm/common/req_manager.py | 1 + .../server/router/model_infer/infer_batch.py | 6 +- .../generic_padded_pre_process.py | 8 ++- .../mode_backend/generic_pre_process.py | 6 +- .../benchmark/static_inference/model_infer.py | 6 +- .../static_inference/model_infer_mtp.py | 8 ++- 10 files changed, 58 insertions(+), 79 deletions(-) diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index ca20a7551..86e204f45 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -12,9 +12,7 @@ from lightllm.common.basemodel.infer_struct import InferStateInfo from lightllm.common.mem_manager import MemoryManager from lightllm.common.req_manager import ReqManager -from lightllm.common.infer_utils import init_req_to_token_indexes from lightllm.common.build_utils import repair_config -from lightllm.common.basemodel.triton_kernel.copy_kv_index_to_req import copy_kv_index_to_req from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager from lightllm.common.basemodel.cuda_graph import CudaGraph from lightllm.common.quantization import Quantcfg @@ -333,14 +331,6 @@ def _prefill( model_input: ModelInput, ): infer_state = self._create_inferstate(model_input) - init_req_to_token_indexes( - self.req_manager.req_to_token_indexs, - model_input.b_req_idx, - model_input.b_seq_len, - infer_state.b_ready_cache_len, - model_input.max_len_in_batch, - infer_state.mem_index, - ) infer_state.init_some_extra_state(self, model_input.input_ids) return self._context_forward(model_input.input_ids, infer_state) @@ -361,12 +351,6 @@ def _decode( find_graph_batch_size = self.graph.find_closest_graph_batch_size(model_input.batch_size) padded_model_input = self._create_padded_decode_model_input(model_input, find_graph_batch_size) infer_state = self._create_inferstate(padded_model_input) - copy_kv_index_to_req( - self.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_seq_len, - infer_state.mem_index, - ) infer_state.init_some_extra_state(self, padded_model_input.input_ids) if self.graph.need_capture(find_graph_batch_size): @@ -382,12 +366,6 @@ def _decode( ) else: infer_state = self._create_inferstate(model_input) - copy_kv_index_to_req( - self.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_seq_len, - infer_state.mem_index, - ) infer_state.init_some_extra_state(self, model_input.input_ids) model_output = self._token_forward(model_input.input_ids, infer_state) @@ -472,25 +450,9 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod input_ids0, input_ids1 = model_input0.input_ids, model_input1.input_ids infer_state0 = self._create_inferstate(model_input0, 0) - init_req_to_token_indexes( - self.req_manager.req_to_token_indexs, - model_input0.b_req_idx, - model_input0.b_seq_len, - infer_state0.b_ready_cache_len, - model_input0.max_len_in_batch, - infer_state0.mem_index, - ) infer_state0.init_some_extra_state(self, input_ids0) infer_state1 = self._create_inferstate(model_input1, 1) - init_req_to_token_indexes( - 
self.req_manager.req_to_token_indexs, - model_input1.b_req_idx, - model_input1.b_seq_len, - infer_state1.b_ready_cache_len, - model_input1.max_len_in_batch, - infer_state1.mem_index, - ) infer_state1.init_some_extra_state(self, input_ids1) model_output0, model_output1 = self._overlap_tpsp_context_forward( @@ -532,20 +494,8 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode padded_model_input0 = self._create_padded_decode_model_input(model_input0, find_graph_batch_size) padded_model_input1 = self._create_padded_decode_model_input(model_input1, find_graph_batch_size) infer_state0 = self._create_inferstate(padded_model_input0, 0) - copy_kv_index_to_req( - self.req_manager.req_to_token_indexs, - infer_state0.b_req_idx, - infer_state0.b_seq_len, - infer_state0.mem_index, - ) infer_state0.init_some_extra_state(self, padded_model_input0.input_ids) infer_state1 = self._create_inferstate(padded_model_input1, 1) - copy_kv_index_to_req( - self.req_manager.req_to_token_indexs, - infer_state1.b_req_idx, - infer_state1.b_seq_len, - infer_state1.mem_index, - ) infer_state1.init_some_extra_state(self, padded_model_input1.input_ids) if self.graph.need_capture(find_graph_batch_size): @@ -570,20 +520,8 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode model_output1 = self._create_unpad_decode_model_output(model_output1, origin_batch_size=origin_batch_size) else: infer_state0 = self._create_inferstate(model_input0, 0) - copy_kv_index_to_req( - self.req_manager.req_to_token_indexs, - infer_state0.b_req_idx, - infer_state0.b_seq_len, - infer_state0.mem_index, - ) infer_state0.init_some_extra_state(self, model_input0.input_ids) infer_state1 = self._create_inferstate(model_input1, 1) - copy_kv_index_to_req( - self.req_manager.req_to_token_indexs, - infer_state1.b_req_idx, - infer_state1.b_seq_len, - infer_state1.mem_index, - ) infer_state1.init_some_extra_state(self, model_input1.input_ids) model_output0, model_output1 = self._overlap_tpsp_token_forward( @@ -684,10 +622,12 @@ def _check_max_len_infer(self): logger.info("begin check max_len infer") dummy_input_ids = torch.ones(self.batch_max_tokens, dtype=torch.int32, device="cuda") b_req_idx = torch.tensor([self.req_manager.alloc()], dtype=torch.int32, device="cuda") - mem_indexes = self.mem_manager.alloc(len(dummy_input_ids)).cuda() b_seq_len = torch.ones(1, dtype=torch.int32, device="cuda") b_seq_len[:] = self.batch_max_tokens b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda") + mem_indexes = self.mem_manager.alloc( + len(dummy_input_ids), b_req_idx, b_seq_len, b_ready_cache_len, True + ).cuda() total_token_num = self.batch_max_tokens b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cuda") model_input = ModelInput( diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/cuda_graph.py index 07792865e..db77298c7 100644 --- a/lightllm/common/basemodel/cuda_graph.py +++ b/lightllm/common/basemodel/cuda_graph.py @@ -196,13 +196,13 @@ def warmup(self, model): total_token_num = batch_size * seq_len max_len_in_batch = self.graph_max_len_in_batch input_ids = torch.tensor([1 for _ in range(batch_size)], dtype=torch.int32, device="cuda") - mem_indexes = model.mem_manager.alloc(len(input_ids)).cuda() b_req_idx = torch.tensor( [model.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cuda" ) b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda") b_seq_len.fill_(seq_len) b_mtp_index = torch.zeros(batch_size, 
dtype=torch.int32, device="cuda") + mem_indexes = model.mem_manager.alloc(len(input_ids), b_req_idx, b_seq_len).cuda() model_input = ModelInput( batch_size=batch_size, @@ -252,13 +252,13 @@ def warmup_overlap(self, model): total_token_num = batch_size * seq_len max_len_in_batch = self.graph_max_len_in_batch input_ids = torch.tensor([1 for _ in range(batch_size)], dtype=torch.int32, device="cuda") - mem_indexes = model.mem_manager.alloc(len(input_ids)).cuda() b_req_idx = torch.tensor( [model.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cuda" ) b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda") b_seq_len.fill_(seq_len) b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda") + mem_indexes = model.mem_manager.alloc(len(input_ids), b_req_idx, b_seq_len).cuda() micro_batch = ModelInput( is_prefill=False, diff --git a/lightllm/common/infer_utils.py b/lightllm/common/infer_utils.py index da2f35e08..c3f980ecd 100644 --- a/lightllm/common/infer_utils.py +++ b/lightllm/common/infer_utils.py @@ -1,6 +1,4 @@ -def init_req_to_token_indexes( - req_to_token_indexs, b_req_idx, b_seq_len, b_ready_cache_len, max_len_in_batch, alloc_mem_index -): +def init_req_to_token_indexes(req_to_token_indexs, b_req_idx, b_seq_len, b_ready_cache_len, alloc_mem_index): start_index = 0 b_seq_len_numpy = b_seq_len.cpu().numpy() b_ready_cache_len_numpy = b_ready_cache_len.cpu().numpy() diff --git a/lightllm/common/mem_manager.py b/lightllm/common/mem_manager.py index 4142ce4aa..aea5bc405 100755 --- a/lightllm/common/mem_manager.py +++ b/lightllm/common/mem_manager.py @@ -12,6 +12,8 @@ from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args from lightllm.distributed.pynccl import PyNcclCommunicator from lightllm.utils.dist_utils import get_current_device_id +from lightllm.common.infer_utils import init_req_to_token_indexes +from lightllm.common.basemodel.triton_kernel.copy_kv_index_to_req import copy_kv_index_to_req logger = init_logger(__name__) @@ -52,6 +54,7 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False layer_num, ) self.HOLD_TOKEN_MEMINDEX = self.size + self.req_to_token_indexs = None def get_cell_size(self): return 2 * self.head_num * self.head_dim * self.layer_num * torch._utils._element_size(self.dtype) @@ -243,7 +246,9 @@ def _write_kv_move_data_p2p(self, token_indexes: torch.Tensor, buffer_tensor: to def _free_buffers(self): self.kv_buffer = None - def alloc(self, need_size) -> torch.Tensor: + def alloc( + self, need_size, b_req_idx=None, b_seq_len=None, b_ready_cache_len=None, is_prefill=False + ) -> torch.Tensor: if need_size > self.mark_end - self.mark_start: logger.error(f"warn no enough cache need_size {need_size} left_size {self.can_use_mem_size}") assert False, "error alloc state" @@ -255,8 +260,29 @@ def alloc(self, need_size) -> torch.Tensor: self.can_use_mem_size -= need_size self.shared_can_use_token_num.set_value(self.can_use_mem_size) + + if self.req_to_token_indexs is not None: + assert b_req_idx is not None and b_seq_len is not None, "b_req_idx and b_seq_len must be provided" + if is_prefill: + init_req_to_token_indexes( + self.req_to_token_indexs, + b_req_idx, + b_seq_len, + b_ready_cache_len, + ans, + ) + else: + copy_kv_index_to_req( + self.req_to_token_indexs, + b_req_idx.cuda(), + b_seq_len.cuda(), + ans.cuda(), + ) return ans + def set_prefix_cache_to_req(self, req_idx: int, start: int, end: int, values: torch.Tensor): + self.req_to_token_indexs[req_idx, 
start:end] = values + def free(self, free_index: Union[torch.Tensor, List[int]]): """_summary_ diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py index dcd1b3072..b22b7f8ab 100644 --- a/lightllm/common/req_manager.py +++ b/lightllm/common/req_manager.py @@ -62,6 +62,7 @@ def __init__(self, max_request_num, max_sequence_length, mem_manager: MemoryMana self.req_to_token_indexs = torch.zeros( (max_request_num + 1, max_sequence_length), dtype=torch.int32, device="cuda" ) + mem_manager.req_to_token_indexs = self.req_to_token_indexs self.mem_manager = mem_manager self.req_sampling_params_manager = ReqSamplingParamsManager(max_request_num) self.max_request_num = max_request_num diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 01ae6c9c5..20b9c0e27 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -340,7 +340,9 @@ def _match_radix_cache(self): self.shared_kv_node = share_node ready_cache_len = share_node.node_prefix_total_len # 从 cpu 到 gpu 是流内阻塞操作 - g_infer_context.req_manager.req_to_token_indexs[self.req_idx, 0:ready_cache_len] = value_tensor + g_infer_context.req_manager.mem_manager.set_prefix_cache_to_req( + self.req_idx, 0, ready_cache_len, value_tensor + ) self.cur_kv_len = int(ready_cache_len) # 序列化问题, 该对象可能为numpy.int64,用 int(*)转换 self.shm_req.prompt_cache_len = self.cur_kv_len # 记录 prompt cache 的命中长度 @@ -458,7 +460,7 @@ def diverse_copy(self, req_manager, is_prefill): req = g_infer_context.requests_mapping[req_id] req.finish_status.set_status(FinishStatus.NO_FINISH) input_len = req.get_chuncked_input_token_len() - req_manager.req_to_token_indexs[req.req_idx][prefix_len:input_len] = cache_token_id + req_manager.mem_manager.set_prefix_cache_to_req(req.req_idx, prefix_len, input_len, cache_token_id) assert input_len == pre_input_len diff --git a/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py index 10090a576..5059ad27f 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py @@ -78,7 +78,9 @@ def padded_prepare_prefill_inputs( g_infer_state_lock.acquire() if g_infer_context.radix_cache is not None: g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0] - padded_req_num) - mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0] - padded_req_num) + mem_indexes = g_infer_context.req_manager.mem_manager.alloc( + input_ids.shape[0] - padded_req_num, b_req_idx, b_seq_len, b_ready_cache_len, True + ) g_infer_state_lock.release() if padded_req_num > 0: @@ -163,7 +165,9 @@ def padded_prepare_decode_inputs( g_infer_state_lock.acquire() if g_infer_context.radix_cache is not None: g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(b_seq_len.shape[0] - padded_req_num) - mem_indexes = g_infer_context.req_manager.mem_manager.alloc(b_seq_len.shape[0] - padded_req_num) + mem_indexes = g_infer_context.req_manager.mem_manager.alloc( + b_seq_len.shape[0] - padded_req_num, b_req_idx, b_seq_len + ) g_infer_state_lock.release() if padded_req_num > 0: diff --git a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py index d5bba1ae5..1c81ebfcd 100644 --- 
a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py @@ -56,7 +56,9 @@ def prepare_prefill_inputs( g_infer_state_lock.acquire() if g_infer_context.radix_cache is not None: g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0]) - mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0]) + mem_indexes = g_infer_context.req_manager.mem_manager.alloc( + input_ids.shape[0], b_req_idx, b_seq_len, b_ready_cache_len, True + ) g_infer_state_lock.release() model_input = ModelInput( @@ -112,7 +114,7 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In g_infer_state_lock.acquire() if g_infer_context.radix_cache is not None: g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(b_seq_len.shape[0]) - mem_indexes = g_infer_context.req_manager.mem_manager.alloc(b_seq_len.shape[0]) + mem_indexes = g_infer_context.req_manager.mem_manager.alloc(b_seq_len.shape[0], b_req_idx, b_seq_len) g_infer_state_lock.release() model_input = ModelInput( diff --git a/test/benchmark/static_inference/model_infer.py b/test/benchmark/static_inference/model_infer.py index b7c07d17a..58a2c44a0 100644 --- a/test/benchmark/static_inference/model_infer.py +++ b/test/benchmark/static_inference/model_infer.py @@ -258,7 +258,9 @@ def run_forward_once( b_seq_len[i] = input_len total_token_num = batch_size * input_len - mem_indexes = model_part.req_manager.mem_manager.alloc(test_data.shape[0]) + mem_indexes = model_part.req_manager.mem_manager.alloc( + test_data.shape[0], b_req_idx, b_seq_len, b_ready_cache_len, True + ).cuda() b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cpu") rank_id = model_kvargs["rank_id"] @@ -321,7 +323,7 @@ def run_forward_once( step_start = time.time() total_token_num += batch_size b_seq_len += 1 - mem_indexes = model_part.req_manager.mem_manager.alloc(predict_ids.shape[0]) + mem_indexes = model_part.req_manager.mem_manager.alloc(predict_ids.shape[0], b_req_idx, b_seq_len).cuda() max_len_in_batch = input_len + i + 1 logits = decode_fn( model_part, diff --git a/test/benchmark/static_inference/model_infer_mtp.py b/test/benchmark/static_inference/model_infer_mtp.py index ba90e709b..9d684fd27 100644 --- a/test/benchmark/static_inference/model_infer_mtp.py +++ b/test/benchmark/static_inference/model_infer_mtp.py @@ -124,7 +124,9 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_ b_seq_len[i] = input_len total_token_num = input_len * batch_size - mem_indexes = main_model.req_manager.mem_manager.alloc(test_data.shape[0]).cuda() + mem_indexes = main_model.req_manager.mem_manager.alloc( + test_data.shape[0], b_req_idx, b_seq_len, b_ready_cache_len, True + ).cuda() # Main model Prefill model_input = ModelInput( batch_size=batch_size, @@ -191,7 +193,9 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_ nopad_b_seq_idx = torch.tensor(nopad_b_seq_idx, dtype=torch.int32, device="cuda") nopad_b_seq_len = torch.tensor(nopad_b_seq_len, dtype=torch.int32, device="cuda") - mem_indexes = main_model.req_manager.mem_manager.alloc(batch_size * (len(draft_models) + 1)).cuda() + mem_indexes = main_model.req_manager.mem_manager.alloc( + batch_size * (len(draft_models) + 1), nopad_b_seq_idx, nopad_b_seq_len + ).cuda() model_input = ModelInput( batch_size=batch_size * (len(draft_models) + 1), From 9369d2d864e6b94fe89b5b3cd4d60bd910975666 Mon Sep 17 00:00:00 2001 From: 
niushengxiao Date: Thu, 31 Jul 2025 15:11:44 +0800 Subject: [PATCH 2/9] feat: add page_size_variable mode for fa3 backend --- lightllm/common/basemodel/basemodel.py | 62 +++ lightllm/common/infer_utils.py | 4 +- lightllm/common/mem_manager.py | 20 - lightllm/common/mem_utils.py | 4 + .../common/page_size_variable_mem_manager.py | 173 +++++++ lightllm/common/req_manager.py | 10 +- .../llama/flashattention_infer_struct.py | 42 +- .../layer_infer/transformer_layer_infer.py | 76 ++- .../dynamic_prompt/paged_radix_cache.py | 462 ++++++++++++++++++ .../router/dynamic_prompt/radix_cache.py | 4 +- .../model_infer/mode_backend/base_backend.py | 4 +- .../generic_padded_pre_process.py | 6 +- .../mode_backend/generic_pre_process.py | 6 +- lightllm/utils/envs_utils.py | 9 + 14 files changed, 835 insertions(+), 47 deletions(-) create mode 100755 lightllm/common/page_size_variable_mem_manager.py create mode 100644 lightllm/server/router/dynamic_prompt/paged_radix_cache.py diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 86e204f45..b746e322c 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -12,7 +12,9 @@ from lightllm.common.basemodel.infer_struct import InferStateInfo from lightllm.common.mem_manager import MemoryManager from lightllm.common.req_manager import ReqManager +from lightllm.common.infer_utils import init_req_to_token_indexes from lightllm.common.build_utils import repair_config +from lightllm.common.basemodel.triton_kernel.copy_kv_index_to_req import copy_kv_index_to_req from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager from lightllm.common.basemodel.cuda_graph import CudaGraph from lightllm.common.quantization import Quantcfg @@ -331,6 +333,14 @@ def _prefill( model_input: ModelInput, ): infer_state = self._create_inferstate(model_input) + init_req_to_token_indexes( + self.req_manager.req_to_token_indexs, + model_input.b_req_idx, + model_input.b_seq_len, + infer_state.b_ready_cache_len, + model_input.max_len_in_batch, + infer_state.mem_index, + ) infer_state.init_some_extra_state(self, model_input.input_ids) return self._context_forward(model_input.input_ids, infer_state) @@ -351,6 +361,12 @@ def _decode( find_graph_batch_size = self.graph.find_closest_graph_batch_size(model_input.batch_size) padded_model_input = self._create_padded_decode_model_input(model_input, find_graph_batch_size) infer_state = self._create_inferstate(padded_model_input) + copy_kv_index_to_req( + self.req_manager.req_to_token_indexs, + infer_state.b_req_idx, + infer_state.b_seq_len, + infer_state.mem_index, + ) infer_state.init_some_extra_state(self, padded_model_input.input_ids) if self.graph.need_capture(find_graph_batch_size): @@ -366,6 +382,12 @@ def _decode( ) else: infer_state = self._create_inferstate(model_input) + copy_kv_index_to_req( + self.req_manager.req_to_token_indexs, + infer_state.b_req_idx, + infer_state.b_seq_len, + infer_state.mem_index, + ) infer_state.init_some_extra_state(self, model_input.input_ids) model_output = self._token_forward(model_input.input_ids, infer_state) @@ -450,9 +472,25 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod input_ids0, input_ids1 = model_input0.input_ids, model_input1.input_ids infer_state0 = self._create_inferstate(model_input0, 0) + init_req_to_token_indexes( + self.req_manager.req_to_token_indexs, + model_input0.b_req_idx, + model_input0.b_seq_len, + infer_state0.b_ready_cache_len, + 
model_input0.max_len_in_batch, + infer_state0.mem_index, + ) infer_state0.init_some_extra_state(self, input_ids0) infer_state1 = self._create_inferstate(model_input1, 1) + init_req_to_token_indexes( + self.req_manager.req_to_token_indexs, + model_input1.b_req_idx, + model_input1.b_seq_len, + infer_state1.b_ready_cache_len, + model_input1.max_len_in_batch, + infer_state1.mem_index, + ) infer_state1.init_some_extra_state(self, input_ids1) model_output0, model_output1 = self._overlap_tpsp_context_forward( @@ -494,8 +532,20 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode padded_model_input0 = self._create_padded_decode_model_input(model_input0, find_graph_batch_size) padded_model_input1 = self._create_padded_decode_model_input(model_input1, find_graph_batch_size) infer_state0 = self._create_inferstate(padded_model_input0, 0) + copy_kv_index_to_req( + self.req_manager.req_to_token_indexs, + infer_state0.b_req_idx, + infer_state0.b_seq_len, + infer_state0.mem_index, + ) infer_state0.init_some_extra_state(self, padded_model_input0.input_ids) infer_state1 = self._create_inferstate(padded_model_input1, 1) + copy_kv_index_to_req( + self.req_manager.req_to_token_indexs, + infer_state1.b_req_idx, + infer_state1.b_seq_len, + infer_state1.mem_index, + ) infer_state1.init_some_extra_state(self, padded_model_input1.input_ids) if self.graph.need_capture(find_graph_batch_size): @@ -520,8 +570,20 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode model_output1 = self._create_unpad_decode_model_output(model_output1, origin_batch_size=origin_batch_size) else: infer_state0 = self._create_inferstate(model_input0, 0) + copy_kv_index_to_req( + self.req_manager.req_to_token_indexs, + infer_state0.b_req_idx, + infer_state0.b_seq_len, + infer_state0.mem_index, + ) infer_state0.init_some_extra_state(self, model_input0.input_ids) infer_state1 = self._create_inferstate(model_input1, 1) + copy_kv_index_to_req( + self.req_manager.req_to_token_indexs, + infer_state1.b_req_idx, + infer_state1.b_seq_len, + infer_state1.mem_index, + ) infer_state1.init_some_extra_state(self, model_input1.input_ids) model_output0, model_output1 = self._overlap_tpsp_token_forward( diff --git a/lightllm/common/infer_utils.py b/lightllm/common/infer_utils.py index c3f980ecd..da2f35e08 100644 --- a/lightllm/common/infer_utils.py +++ b/lightllm/common/infer_utils.py @@ -1,4 +1,6 @@ -def init_req_to_token_indexes(req_to_token_indexs, b_req_idx, b_seq_len, b_ready_cache_len, alloc_mem_index): +def init_req_to_token_indexes( + req_to_token_indexs, b_req_idx, b_seq_len, b_ready_cache_len, max_len_in_batch, alloc_mem_index +): start_index = 0 b_seq_len_numpy = b_seq_len.cpu().numpy() b_ready_cache_len_numpy = b_ready_cache_len.cpu().numpy() diff --git a/lightllm/common/mem_manager.py b/lightllm/common/mem_manager.py index aea5bc405..f3cf0419d 100755 --- a/lightllm/common/mem_manager.py +++ b/lightllm/common/mem_manager.py @@ -12,8 +12,6 @@ from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args from lightllm.distributed.pynccl import PyNcclCommunicator from lightllm.utils.dist_utils import get_current_device_id -from lightllm.common.infer_utils import init_req_to_token_indexes -from lightllm.common.basemodel.triton_kernel.copy_kv_index_to_req import copy_kv_index_to_req logger = init_logger(__name__) @@ -260,24 +258,6 @@ def alloc( self.can_use_mem_size -= need_size self.shared_can_use_token_num.set_value(self.can_use_mem_size) - - if self.req_to_token_indexs 
is not None: - assert b_req_idx is not None and b_seq_len is not None, "b_req_idx and b_seq_len must be provided" - if is_prefill: - init_req_to_token_indexes( - self.req_to_token_indexs, - b_req_idx, - b_seq_len, - b_ready_cache_len, - ans, - ) - else: - copy_kv_index_to_req( - self.req_to_token_indexs, - b_req_idx.cuda(), - b_seq_len.cuda(), - ans.cuda(), - ) return ans def set_prefix_cache_to_req(self, req_idx: int, start: int, end: int, values: torch.Tensor): diff --git a/lightllm/common/mem_utils.py b/lightllm/common/mem_utils.py index dfb8e849d..5f3ee6164 100644 --- a/lightllm/common/mem_utils.py +++ b/lightllm/common/mem_utils.py @@ -4,6 +4,7 @@ from lightllm.common.export_calibration_mem_manager import ExportCalibrationMemoryManager from lightllm.common.ppl_int8kv_mem_manager import PPLINT8KVMemoryManager from lightllm.common.ppl_int4kv_mem_manager import PPLINT4KVMemoryManager +from lightllm.common.page_size_variable_mem_manager import PageSizeVariableMemoryManager from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) @@ -28,6 +29,9 @@ def select_mem_manager_class(mode): elif "export_fp8kv_calibration" in mode: memory_manager_class = ExportCalibrationMemoryManager logger.info("Using mode export fp8kv calibration") + elif "page_size_variable" in mode: + memory_manager_class = PageSizeVariableMemoryManager + logger.info("Page size will be variable") else: memory_manager_class = MemoryManager logger.info("Model kv cache using mode normal") diff --git a/lightllm/common/page_size_variable_mem_manager.py b/lightllm/common/page_size_variable_mem_manager.py new file mode 100755 index 000000000..095648c01 --- /dev/null +++ b/lightllm/common/page_size_variable_mem_manager.py @@ -0,0 +1,173 @@ +import torch +import numpy as np +from .mem_manager import MemoryManager +from typing import List, Union +from lightllm.utils.log_utils import init_logger +from lightllm.utils.envs_utils import get_page_size + + +def cdiv(a, b): + return (a + b - 1) // b + + +logger = init_logger(__name__) + + +class PageSizeVariableMemoryManager(MemoryManager): + def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): + super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction) + self.req_to_page_indexs = None + page_size = get_page_size() + self.page_idx_pool = torch.arange( + 0, cdiv(self.size, page_size), dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True + ) + self.mark_page_start = 0 + self.can_use_page_size = cdiv(self.size, page_size) + + def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): + self.kv_buffer = torch.empty( + (layer_num, cdiv(size, get_page_size()) * get_page_size(), 2 * head_num, head_dim), + dtype=dtype, + device="cuda", + ) + + # 要求长度必须是page_size的整数倍,page内token索引必须连续 + def check_cache_page_valid(self, values: torch.Tensor): + end = len(values) + assert end % self.page_size == 0, "Values length must be a multiple of page size" + total_pages = end // self.page_size + for page_idx in range(total_pages): + values_start = page_idx * self.page_size + values_end = min((page_idx + 1) * self.page_size, end) + page_token_idxs = values[values_start:values_end] + if len(page_token_idxs) > 1: + expected_idxs = torch.arange( + page_token_idxs[0], + page_token_idxs[0] + len(page_token_idxs), + dtype=page_token_idxs.dtype, + device=page_token_idxs.device, + ) + if not torch.equal(page_token_idxs, expected_idxs): + return False + return True + + def 
set_prefix_cache_to_req(self, req_idx: int, start: int, end: int, values: torch.Tensor): + # assert self.check_cache_page_valid(values), "Values must be valid for page size" + page_size = get_page_size() + self.req_to_page_indexs[req_idx, start // page_size : end // page_size] = values[::page_size] // page_size + self.req_to_token_indexs[req_idx, start:end] = values + + def expand_by_page_size(self, b_token_len, page_size): + # 将seq_len按page整数倍展开,例如seq_len = [9,9,9] -> page_len = [4,4,1,4,4,1,4,4,1], page_size = 4 + b_page_len = cdiv(b_token_len, page_size) + need_pages_num = b_page_len.sum() + p_token_len = torch.full((need_pages_num,), page_size, dtype=b_token_len.dtype, device=b_token_len.device) + cumsum_pages = torch.cumsum(b_page_len, dim=0) + last_page_positions = cumsum_pages - 1 + remainders = b_token_len - (b_page_len - 1) * page_size + p_token_len[last_page_positions] = remainders + return need_pages_num, b_page_len, p_token_len + + def get_paged_token_indexs(self, b_req_idx, page_size, b_seq_len, b_ready_cache_len, is_prefill): + if is_prefill: + b_req_idx = b_req_idx.cuda() + b_seq_len = b_seq_len.cuda() + b_ready_cache_len = b_ready_cache_len.cuda() + + b_token_len = b_seq_len - b_ready_cache_len + total_pages_needed, b_page_len, p_token_len = self.expand_by_page_size(b_token_len, page_size) + if self.can_use_page_size < total_pages_needed: + raise RuntimeError( + f"No available pages for alloc. remaining: {self.can_use_page_size}, needed: {total_pages_needed}" + ) + + allocated_pages = self.page_idx_pool[ + self.mark_page_start : self.mark_page_start + total_pages_needed + ].cuda() + + def get_offsets_by_length(b_len, max_len): + # 例:b_len = [3,4,5] -> [0,1,2,0,1,2,3,0,1,2,3,4] + offsets = torch.arange(max_len, dtype=b_len.dtype, device=b_len.device) + offset_mask = offsets.unsqueeze(0) < b_len.unsqueeze(1) + return torch.masked_select(offsets, offset_mask) + + page_offsets = get_offsets_by_length(b_page_len, b_page_len.max()) + token_offsets = get_offsets_by_length(p_token_len, page_size) + + # 更新req_to_page_indexs, b_ready_cache_len必整除page_size + page_starts = b_ready_cache_len // page_size + req_id = torch.repeat_interleave( + torch.arange(len(b_req_idx), dtype=b_token_len.dtype, device=b_token_len.device), b_page_len + ) + self.req_to_page_indexs[b_req_idx[req_id], page_starts[req_id] + page_offsets] = allocated_pages + + self.mark_page_start += total_pages_needed + self.can_use_page_size -= total_pages_needed + page_bases = allocated_pages * page_size + return torch.repeat_interleave(page_bases, p_token_len) + token_offsets + else: + b_seq_len = b_seq_len.cuda() + b_req_idx = b_req_idx.cuda() + need_new_page_mask = (b_seq_len - 1) % page_size == 0 + new_pages_num = need_new_page_mask.sum() + if self.can_use_page_size < new_pages_num: + raise RuntimeError( + f"No available pages for alloc. 
remaining: {self.can_use_page_size}, needed: {new_pages_num}" + ) + + token_idxs = torch.zeros_like(b_seq_len, device=b_seq_len.device) + if new_pages_num > 0: + new_pages = self.page_idx_pool[self.mark_page_start : self.mark_page_start + new_pages_num].cuda() + self.mark_page_start += new_pages_num + self.can_use_page_size -= new_pages_num + token_idxs[need_new_page_mask] = new_pages * page_size + + # 需要更新req_to_page_indexs + new_page_req_indices = b_req_idx[need_new_page_mask] + page_positions = (b_seq_len[need_new_page_mask] - 1) // page_size + self.req_to_page_indexs[new_page_req_indices, page_positions] = new_pages + + mask = ~need_new_page_mask + if mask.any(): + seq_lens = b_seq_len[mask] + token_idxs[mask] = ( + self.req_to_token_indexs[b_req_idx[mask], seq_lens - 2] // page_size * page_size + + (seq_lens - 1) % page_size + ) + return token_idxs + + def alloc(self, need_size, b_req_idx, b_seq_len, b_ready_cache_len=None, is_prefill=False) -> torch.Tensor: + page_size = get_page_size() + token_idxs = self.get_paged_token_indexs(b_req_idx, page_size, b_seq_len, b_ready_cache_len, is_prefill) + self.can_use_mem_size -= need_size + self.shared_can_use_token_num.set_value(self.can_use_mem_size) + return token_idxs + + def free(self, free_index: Union[torch.Tensor, List[int]]): + self.can_use_mem_size += len(free_index) + self.shared_can_use_token_num.set_value(self.can_use_mem_size) + + page_size = get_page_size() + if isinstance(free_index, list): + free_index = torch.tensor(free_index, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True) + + if len(free_index) == 0: + return + + page_indices = free_index // page_size + unique_pages = torch.unique(page_indices) + for page_idx in sorted(unique_pages, reverse=True): # 逆序放回,保持池的相对顺序 + self.mark_page_start -= 1 + self.page_idx_pool[self.mark_page_start] = page_idx + self.can_use_page_size += 1 + + return + + def free_all(self): + super().free_all() + page_size = get_page_size() + self.mark_page_start = 0 + self.can_use_page_size = cdiv(self.size, page_size) + self.page_idx_pool = torch.arange( + 0, cdiv(self.size, page_size), dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True + ) diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py index b22b7f8ab..fb5d564d6 100644 --- a/lightllm/common/req_manager.py +++ b/lightllm/common/req_manager.py @@ -5,7 +5,7 @@ from typing import List, Optional from lightllm.common.basemodel.triton_kernel.gen_sampling_params import token_id_counter from lightllm.common.basemodel.triton_kernel.gen_sampling_params import update_req_to_token_id_counter -from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args +from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args, get_page_size from lightllm.utils.config_utils import get_vocab_size logger = init_logger(__name__) @@ -63,6 +63,14 @@ def __init__(self, max_request_num, max_sequence_length, mem_manager: MemoryMana (max_request_num + 1, max_sequence_length), dtype=torch.int32, device="cuda" ) mem_manager.req_to_token_indexs = self.req_to_token_indexs + if hasattr(mem_manager, "req_to_page_indexs"): + page_size = get_page_size() + self.req_to_page_indexs = torch.zeros( + (max_request_num + 1, (max_sequence_length + page_size - 1) // page_size), + dtype=torch.int32, + device="cuda", + ) + mem_manager.req_to_page_indexs = self.req_to_page_indexs self.mem_manager = mem_manager self.req_sampling_params_manager = ReqSamplingParamsManager(max_request_num) self.max_request_num = 
max_request_num diff --git a/lightllm/models/llama/flashattention_infer_struct.py b/lightllm/models/llama/flashattention_infer_struct.py index 98f628f07..1f249c199 100644 --- a/lightllm/models/llama/flashattention_infer_struct.py +++ b/lightllm/models/llama/flashattention_infer_struct.py @@ -3,12 +3,16 @@ import numpy as np import torch.distributed as dist from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.utils.envs_utils import get_env_start_args +from lightllm.utils.envs_utils import get_env_start_args, get_page_size from lightllm.utils.dist_utils import get_current_device_id from lightllm.models.deepseek2.triton_kernel.repack_kv_index import repack_kv_index from lightllm.common.basemodel.batch_objs import ModelInput +def cdiv(a, b): + return (a + b - 1) // b + + class FlashAttentionStateInfo(LlamaInferStateInfo): _shared_page_table_buffer = None @@ -28,32 +32,34 @@ def _init_flash_attention_state(self, model, input_ids: torch.Tensor): if self.is_prefill: self.cu_seqlens_q = self.b1_cu_q_seq_len.int() self.cu_seqlens_k = self.b1_cu_kv_seq_len.int() - self.page_table = torch.empty( - (self.batch_size, self.max_seq_len), dtype=torch.int32, device=input_ids.device - ) - self.page_table.copy_(model.req_manager.req_to_token_indexs[self.b_req_idx, : self.max_seq_len]) + length = cdiv(self.max_seq_len, get_page_size()) + self.page_table = torch.empty((self.batch_size, length), dtype=torch.int32, device=input_ids.device) + if "page_size_variable" in model.mode: + self.page_table.copy_(model.req_manager.req_to_page_indexs[self.b_req_idx, :length]) + else: + self.page_table.copy_(model.req_manager.req_to_token_indexs[self.b_req_idx, :length]) else: # Meta information of flashattention for decoding self.cu_seqlens_q = self.b1_cu_q_seq_len.int() self.cu_seqlens_k = self.b1_cu_kv_seq_len.int() max_seq_len_k = self.max_kv_seq_len if self.batch_size <= model.graph_max_batch_size and self.max_len_in_batch <= model.graph_max_len_in_batch: - page_buffer = FlashAttentionStateInfo.get_page_table_buffer( - model.graph_max_batch_size, model.graph_max_len_in_batch + page_size = get_page_size() + length = cdiv(model.graph_max_len_in_batch, page_size) + page_buffer = FlashAttentionStateInfo.get_page_table_buffer(model.graph_max_batch_size, length) + self.page_table = page_buffer[self.microbatch_index][: self.batch_size * length].reshape( + self.batch_size, length ) - self.page_table = page_buffer[self.microbatch_index][ - : self.batch_size * model.graph_max_len_in_batch - ].reshape(self.batch_size, model.graph_max_len_in_batch) else: - self.page_table = torch.empty( - (self.batch_size, self.max_len_in_batch), dtype=torch.int32, device=input_ids.device - ) + length = cdiv(self.max_len_in_batch, get_page_size()) + self.page_table = torch.empty((self.batch_size, length), dtype=torch.int32, device=input_ids.device) - self.page_table[:, :max_seq_len_k].copy_( - model.req_manager.req_to_token_indexs[self.b_req_idx, :max_seq_len_k], - non_blocking=True, - ) - self.page_table[:, max_seq_len_k:].fill_(0) + length = cdiv(max_seq_len_k, get_page_size()) + if "page_size_variable" in model.mode: + self.page_table[:, :length].copy_(model.req_manager.req_to_page_indexs[self.b_req_idx, :length]) + else: + self.page_table[:, :length].copy_(model.req_manager.req_to_token_indexs[self.b_req_idx, :length]) + self.page_table[:, length:].fill_(0) if "offline_calibration_fp8kv" in model.mode: if self.is_prefill: diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py 
b/lightllm/models/llama/layer_infer/transformer_layer_infer.py index b00215cff..6683df346 100755 --- a/lightllm/models/llama/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/llama/layer_infer/transformer_layer_infer.py @@ -27,7 +27,7 @@ from lightllm.models.llama.triton_kernel.ppl_quant_copy_kv import destindex_copy_dequantize_kv from lightllm.distributed.communication_op import all_gather_into_tensor, reduce_scatter_tensor from lightllm.utils.log_utils import init_logger -from lightllm.utils.envs_utils import get_env_start_args +from lightllm.utils.envs_utils import get_env_start_args, get_page_size from lightllm.utils.light_utils import HAS_LIGHTLLM_KERNEL, light_ops from lightllm.common.basemodel.triton_kernel.q_per_head_fp8_quant import q_per_head_fp8_quant from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops @@ -87,6 +87,14 @@ def _bind_attention(self): self._copy_kv_to_mem_cache = partial( LlamaTransformerLayerInfer._copy_kv_to_mem_cache_with_calibration, self ) + elif "page_size_variable" in self.mode: + self._context_attention_kernel = partial( + LlamaTransformerLayerInfer._paged_context_attention_flashattention, self + ) + self._token_attention_kernel = partial( + LlamaTransformerLayerInfer._paged_token_decode_attention_flashattention, self + ) + self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self) elif not self.mode: self._context_attention_kernel = partial( LlamaTransformerLayerInfer._context_attention_flashattention, self @@ -317,6 +325,39 @@ def _context_attention_kernel_ppl_int8kv( ) return o_tensor + def _paged_context_attention_flashattention( + self, q, kv, infer_state: FlashAttentionStateInfo, layer_weight, out=None + ): + page_size = get_page_size() + cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :].reshape( + -1, page_size, self.tp_k_head_num_, self.head_dim_ + ) + cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][ + :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : + ].reshape(-1, page_size, self.tp_v_head_num_, self.head_dim_) + q = q.reshape(-1, self.tp_q_head_num_, self.head_dim_) + k_descale, v_descale = None, None # disable quantization + Lq = q.shape[-1] + sm_scale = 1.0 / (Lq ** 0.5) + o = flash_attn_with_kvcache( + q=q, + k_cache=cache_k, + v_cache=cache_v, + page_table=infer_state.page_table, + cache_seqlens=infer_state.b_seq_len, + cu_seqlens_q=infer_state.cu_seqlens_q, + cu_seqlens_k_new=infer_state.cu_seqlens_k, + max_seqlen_q=infer_state.q_max_seq_len, + softmax_scale=sm_scale, + causal=True, + window_size=(-1, -1), + softcap=0.0, + k_descale=k_descale, + v_descale=v_descale, + return_softmax_lse=False, + ) + return o + def _context_attention_flashattention(self, q, kv, infer_state: FlashAttentionStateInfo, layer_weight, out=None): cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :].reshape( -1, 1, self.tp_k_head_num_, self.head_dim_ @@ -824,6 +865,39 @@ def _token_decode_attention_gqa_flashdecoding_vsm( alloc_tensor_func=self.alloc_tensor, ) + def _paged_token_decode_attention_flashattention( + self, q, infer_state: FlashAttentionStateInfo, layer_weight, out=None + ): + page_size = get_page_size() + cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :].reshape( + -1, page_size, self.tp_k_head_num_, self.head_dim_ + ) + cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][ + :, self.tp_k_head_num_ : self.tp_k_head_num_ + 
self.tp_v_head_num_, : + ].reshape(-1, page_size, self.tp_v_head_num_, self.head_dim_) + q = q.reshape(-1, self.tp_q_head_num_, self.head_dim_) + k_descale, v_descale = None, None # disable quantization + Lq = q.shape[-1] + sm_scale = 1.0 / (Lq ** 0.5) + o = flash_attn_with_kvcache( + q=q, + k_cache=cache_k, + v_cache=cache_v, + page_table=infer_state.page_table, + cache_seqlens=infer_state.b_seq_len, + cu_seqlens_q=infer_state.cu_seqlens_q, + cu_seqlens_k_new=infer_state.cu_seqlens_k, + max_seqlen_q=1, + softmax_scale=sm_scale, + causal=False, + window_size=(-1, -1), + softcap=0.0, + k_descale=k_descale, + v_descale=v_descale, + return_softmax_lse=False, + ) + return o + def _token_decode_attention_flashattention(self, q, infer_state: FlashAttentionStateInfo, layer_weight, out=None): cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :].reshape( -1, 1, self.tp_k_head_num_, self.head_dim_ diff --git a/lightllm/server/router/dynamic_prompt/paged_radix_cache.py b/lightllm/server/router/dynamic_prompt/paged_radix_cache.py new file mode 100644 index 000000000..33210b90d --- /dev/null +++ b/lightllm/server/router/dynamic_prompt/paged_radix_cache.py @@ -0,0 +1,462 @@ +# Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/managers/router/radix_cache.py +import torch +import numpy as np +from typing import Tuple, Dict, Set, List +from sortedcontainers import SortedSet +from .shared_arr import SharedArray +from lightllm.utils.envs_utils import get_page_size + + +class UniqueTimeIdGenerator: + def __init__(self): + self.counter = 0 + + def generate_time_id(self): + self.counter += 1 + return self.counter + + +time_gen = UniqueTimeIdGenerator() + + +class TreeNode: + def __init__(self): + self.children: Dict[int, TreeNode] = {} # page_hash -> TreeNode + self.parent: TreeNode = None + self.token_id_key: torch.Tensor = None + self.token_mem_index_value: torch.Tensor = None # 用于记录存储的 token_index 为每个元素在 token mem 中的index位置 + self.ref_counter = 0 + self.time_id = time_gen.generate_time_id() # 用于标识时间周期 + + self.node_value_len = 0 + self.node_prefix_total_len = 0 + self.total_children_count = 0 + self.page_size = get_page_size() + self._page_size_is_power_of_2 = (self.page_size & (self.page_size - 1)) == 0 + self._page_size_mask = self.page_size - 1 if self._page_size_is_power_of_2 else None + + def get_compare_key(self): + return (0 if self.ref_counter == 0 else 1, self.total_children_count, self.time_id) + + def _compute_key(self, tokens: torch.Tensor) -> int: + page_tokens = tokens[: self.page_size] + return page_tokens.item() if self.page_size == 1 else hash(page_tokens.cpu().numpy().tobytes()) + + def find_matched_child(self, token_id_key: torch.Tensor) -> Tuple["TreeNode", int]: + target_key = self._compute_key(token_id_key) + if target_key in self.children: + child = self.children[target_key] + prefix_len = match(token_id_key, child.token_id_key) + # 只匹配page_size的整数倍长度 + if self.page_size > 1: + if prefix_len % self.page_size != 0: + if self._page_size_is_power_of_2: + # 位运算加速 + prefix_len = prefix_len & ~self._page_size_mask + else: + prefix_len = (prefix_len // self.page_size) * self.page_size + if prefix_len == 0: + return None, 0 + return child, prefix_len + + return None, 0 + + def split_node(self, prefix_len): + split_parent_node = TreeNode() + split_parent_node.parent = self.parent + self.parent.children[self._compute_key(self.token_id_key)] = split_parent_node + + split_parent_node.token_id_key = self.token_id_key[0:prefix_len] 
+ split_parent_node.token_mem_index_value = self.token_mem_index_value[0:prefix_len] + split_parent_node.children = {} + + remaining_tokens = self.token_id_key[prefix_len:] + split_parent_node.children[self._compute_key(remaining_tokens)] = self + split_parent_node.ref_counter = self.ref_counter + split_parent_node.total_children_count = 1 + + new_len = len(split_parent_node.token_mem_index_value) + split_parent_node.node_value_len = new_len + split_parent_node.node_prefix_total_len = split_parent_node.parent.node_prefix_total_len + new_len + + self.token_id_key = remaining_tokens + self.token_mem_index_value = self.token_mem_index_value[prefix_len:] + self.parent = split_parent_node + new_len = len(self.token_mem_index_value) + self.node_value_len = new_len + self.node_prefix_total_len = self.parent.node_prefix_total_len + new_len + return split_parent_node + + def add_and_return_new_child(self, token_id_key, token_mem_index_value): + child = TreeNode() + child.token_id_key = token_id_key + child.token_mem_index_value = token_mem_index_value + + self.children[self._compute_key(token_id_key)] = child + child.parent = self + self.total_children_count += 1 + + new_len = len(child.token_mem_index_value) + child.node_value_len = new_len + child.node_prefix_total_len = child.parent.node_prefix_total_len + new_len + return child + + def remove_child(self, child_node: "TreeNode"): + del self.children[self._compute_key(child_node.token_id_key)] + child_node.parent = None + self.total_children_count -= 1 + return + + def update_time(self): + self.time_id = time_gen.generate_time_id() + + def is_leaf(self): + return self.total_children_count == 0 + + +def match(t1: torch.Tensor, t2: torch.Tensor) -> int: + # Ensure same shape for comparison: flatten and get min length + t1_flat = t1.flatten() + t2_flat = t2.flatten() + min_len = min(t1_flat.size(0), t2_flat.size(0)) + + # Compare elements and find first mismatch + diff = t1_flat[:min_len] != t2_flat[:min_len] + mismatch_indices = torch.nonzero(diff) + + if mismatch_indices.numel() == 0: + return min_len # All matched up to min_len + else: + return mismatch_indices[0].item() + + +class PagedRadixCache: + """ + unique_name 主要用于解决单机,多实列部署时的shm冲突 + """ + + def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager=None): + self.mem_manager = mem_manager + self._key_dtype = torch.int64 + self._value_dtype = torch.int64 + # 预计算page_size相关的常量 + self.page_size = get_page_size() + self._page_size_is_power_of_2 = (self.page_size & (self.page_size - 1)) == 0 + self._page_size_mask = self.page_size - 1 if self._page_size_is_power_of_2 else None + + self.root_node = TreeNode() + self.root_node.token_id_key = torch.zeros((0,), device="cpu", dtype=self._key_dtype) + self.root_node.token_mem_index_value = torch.zeros((0,), device="cpu", dtype=self._value_dtype) + self.root_node.ref_counter = 1 # 初始化为 1 保证永远不会被 evict 掉 + + self.evict_tree_set: Set[TreeNode] = SortedSet(key=lambda x: x.get_compare_key()) # 自定义比较器 + self.evict_tree_set.add(self.root_node) + + self.refed_tokens_num = SharedArray(f"{unique_name}_refed_tokens_num_{rank_in_node}", (1,), dtype=np.int64) + self.refed_tokens_num.arr[0] = 0 + self.tree_total_tokens_num = SharedArray( + f"{unique_name}_tree_total_tokens_num_{rank_in_node}", (1,), dtype=np.int64 + ) + self.tree_total_tokens_num.arr[0] = 0 + + def _get_page_aligned_key(self, key, value=None): + aligned_len = len(key) + if aligned_len == 0: + return None, None + # page_size > 1时, 需要确保输入的key长度是page_size的整数倍 + if self.page_size > 1: 
+ if aligned_len % self.page_size != 0: + if self._page_size_is_power_of_2: + # 位运算加速 + aligned_len = aligned_len & ~self._page_size_mask + else: + aligned_len = (aligned_len // self.page_size) * self.page_size + return ( + key[:aligned_len] if aligned_len > 0 else None, + value[:aligned_len] if value is not None and aligned_len > 0 else None, + ) + return key, value + + def insert(self, key, value=None): + if value is None: + value = key + + assert len(key) == len(value) # and len(key) >= 1 + key, value = self._get_page_aligned_key(key, value) + if key is None: + return 0 + return self._insert_helper(self.root_node, key, value) + + def _insert_helper(self, node: TreeNode, key, value): + if node.is_leaf(): + self.evict_tree_set.discard(node) + + try: + child, prefix_len = node.find_matched_child(key) + if child is not None: + if prefix_len == len(key): + if child.is_leaf(): + self.evict_tree_set.discard(child) + child.update_time() + if child.is_leaf(): + self.evict_tree_set.add(child) + return prefix_len + elif prefix_len < len(key) and prefix_len < len(child.token_id_key): + if child.is_leaf(): + self.evict_tree_set.discard(child) + + remaining_key = key[prefix_len:] + remaining_value = value[prefix_len:] + split_parent_node = child.split_node(prefix_len) + new_node = split_parent_node.add_and_return_new_child(remaining_key, remaining_value) + # update total token num + self.tree_total_tokens_num.arr[0] += len(new_node.token_mem_index_value) + if new_node.is_leaf(): + self.evict_tree_set.add(new_node) + + if split_parent_node.is_leaf(): + self.evict_tree_set.add(split_parent_node) + + if child.is_leaf(): + self.evict_tree_set.add(child) + return prefix_len + elif prefix_len < len(key) and prefix_len == len(child.token_id_key): + return prefix_len + self._insert_helper(child, key[prefix_len:], value[prefix_len:]) + else: + assert False, "can not run to here" + + new_node = node.add_and_return_new_child(key, value) + # update total token num + self.tree_total_tokens_num.arr[0] += len(new_node.token_mem_index_value) + if new_node.is_leaf(): + self.evict_tree_set.add(new_node) + return 0 + finally: + node.update_time() + if node.is_leaf(): + self.evict_tree_set.add(node) + + def match_prefix(self, key, update_refs=False): + assert len(key) != 0 + key, _ = self._get_page_aligned_key(key) + if key is None: + return None, 0, None + + ans_value_list = [] + tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs) + if tree_node != self.root_node: + if len(ans_value_list) != 0: + value = torch.concat(ans_value_list) + else: + value = torch.zeros((0,), device="cpu", dtype=self._value_dtype) + return tree_node, len(value), value + else: + self.dec_node_ref_counter(self.root_node) + return None, 0, None + + def _match_prefix_helper(self, node: TreeNode, key, ans_value_list: list, update_refs=False) -> TreeNode: + if node.is_leaf(): + self.evict_tree_set.discard(node) + + if update_refs: + node.ref_counter += 1 + # from 0 to 1 need update refs token num + if node.ref_counter == 1: + self.refed_tokens_num.arr[0] += len(node.token_mem_index_value) + + try: + if len(key) == 0: + return node + + child, prefix_len = node.find_matched_child(key) + if child is not None: + if prefix_len == len(child.token_id_key): + ans_value_list.append(child.token_mem_index_value) + return self._match_prefix_helper(child, key[prefix_len:], ans_value_list, update_refs=update_refs) + elif prefix_len < len(child.token_id_key): + if child.is_leaf(): + self.evict_tree_set.discard(child) 
+ + split_parent_node = child.split_node(prefix_len) + ans_value_list.append(split_parent_node.token_mem_index_value) + + if update_refs: + split_parent_node.ref_counter += 1 + # from 0 to 1 need update refs token num + if split_parent_node.ref_counter == 1: + self.refed_tokens_num.arr[0] += len(split_parent_node.token_mem_index_value) + + if child.is_leaf(): + self.evict_tree_set.add(child) + if split_parent_node.is_leaf(): + self.evict_tree_set.add(split_parent_node) + + return split_parent_node + else: + assert False, "error state" + + return node + finally: + node.update_time() + if node.is_leaf(): + self.evict_tree_set.add(node) + + def evict(self, need_remove_tokens, evict_callback): + if self.tree_total_tokens_num.arr[0] - self.refed_tokens_num.arr[0] < need_remove_tokens: + assert False, f"""can not free tree tokens {need_remove_tokens}, + tree_total_tokens_num {self.tree_total_tokens_num.arr[0]}, + refed_tokens_num {self.refed_tokens_num.arr[0]}""" + num_evicted = 0 + while num_evicted < need_remove_tokens: + node: TreeNode = self.evict_tree_set.pop(0) + assert node.ref_counter == 0 and node.is_leaf() and node != self.root_node, "error evict tree node state" + num_evicted += len(node.token_mem_index_value) + evict_callback(node.token_mem_index_value) + # update total token num + self.tree_total_tokens_num.arr[0] -= len(node.token_mem_index_value) + parent_node: TreeNode = node.parent + parent_node.remove_child(node) + if parent_node.is_leaf(): + self.evict_tree_set.add(parent_node) + + return + + def assert_leafs_is_right(self): + for node in self.evict_tree_set: + if node.is_leaf() and node.ref_counter == 0: + a = node.token_mem_index_value.cuda() + assert (self.mem_manager.mem_state[a] == 1).sum().item() == len(a) + + def clear_tree_nodes(self): + """ + 该函数只在测试时调用 + """ + while True: + node: TreeNode = self.evict_tree_set.pop(0) + if node != self.root_node: + parent_node: TreeNode = node.parent + parent_node.remove_child(node) + if parent_node.is_leaf(): + self.evict_tree_set.add(parent_node) + else: + break + + self.tree_total_tokens_num.arr[0] = 0 + self.refed_tokens_num.arr[0] = 0 + return + + def dec_node_ref_counter(self, node: TreeNode): + if node is None: + return + # 如果减引用的是叶节点,需要先从 evict_tree_set 中移除 + old_node = node + if old_node.is_leaf(): + self.evict_tree_set.discard(old_node) + + while node is not None: + if node.ref_counter == 1: + self.refed_tokens_num.arr[0] -= len(node.token_mem_index_value) + node.ref_counter -= 1 + node = node.parent + + # 加回。 + if old_node.is_leaf(): + self.evict_tree_set.add(old_node) + return + + def get_refed_tokens_num(self): + return self.refed_tokens_num.arr[0] + + def get_tree_total_tokens_num(self): + return self.tree_total_tokens_num.arr[0] + + def print_self(self, indent=0): + self._print_helper(self.root_node, indent) + + def _print_helper(self, node: TreeNode, indent): + print( + " " * indent, + f"k: {node.token_id_key[0:10]} v: {node.token_mem_index_value[0:10]} refs: {node.ref_counter} \ + time_id: {node.time_id} prefix_total_len: {node.node_prefix_total_len} \ + node_value_len: {node.node_value_len}", + ) + for _, child in node.children.items(): + self._print_helper(child, indent=indent + 2) + return + + def free_radix_cache_to_get_enough_token( + self, need_token_num=None, b_seq_len=None, b_ready_cache_len=None, is_prefill=False + ): + assert self.mem_manager is not None + need_pages = 0 + can_use_pages = 0 + if hasattr(self.mem_manager, "can_use_page_size") and self.page_size > 1 and b_seq_len is not None: + + def 
get_need_page_size(page_size, b_seq_len, b_ready_cache_len=None, is_prefill=False): + need_new_pages = 0 + if is_prefill: + need_tokens_array = b_seq_len - b_ready_cache_len + need_pages_array = (need_tokens_array + page_size - 1) // page_size + need_new_pages = need_pages_array.sum() + else: + mask = (b_seq_len - 1) % page_size == 0 + need_new_pages = mask.sum() + return need_new_pages + + need_pages = get_need_page_size(self.page_size, b_seq_len, b_ready_cache_len, is_prefill) + can_use_pages = self.mem_manager.can_use_page_size + if need_token_num > self.mem_manager.can_use_mem_size or need_pages > can_use_pages: + need_evict_single_token_num = need_token_num - self.mem_manager.can_use_mem_size + need_evict_page_token_num = (need_pages - can_use_pages) * self.page_size + need_evict_token_num = max(need_evict_single_token_num, need_evict_page_token_num) + remaining_tokens = self.get_tree_total_tokens_num() - self.get_refed_tokens_num() + need_evict_token_num = min(need_evict_token_num, remaining_tokens) + release_mems = [] + + def release_mem(mem_index): + release_mems.append(mem_index) + return + + self.evict(need_evict_token_num, release_mem) + if release_mems: + mem_index = torch.concat(release_mems) + self.mem_manager.free(mem_index) + return + + +class _RadixCacheReadOnlyClient: + """ + router 端只读用的客户端,用于从共享内存中读取树结构中的信息,用于进行prompt cache 的调度估计。 + """ + + def __init__(self, unique_name, total_token_num, rank_in_node): + self.refed_tokens_num = SharedArray(f"{unique_name}_refed_tokens_num_{rank_in_node}", (1,), dtype=np.int64) + self.tree_total_tokens_num = SharedArray( + f"{unique_name}_tree_total_tokens_num_{rank_in_node}", (1,), dtype=np.int64 + ) + + def get_refed_tokens_num(self): + return self.refed_tokens_num.arr[0] + + def get_tree_total_tokens_num(self): + return self.tree_total_tokens_num.arr[0] + + def get_unrefed_tokens_num(self): + return self.tree_total_tokens_num.arr[0] - self.refed_tokens_num.arr[0] + + +class RadixCacheReadOnlyClient: + def __init__(self, unique_name, total_token_num, node_world_size, dp_world_size): + self.dp_rank_clients: List[_RadixCacheReadOnlyClient] = [ + _RadixCacheReadOnlyClient(unique_name, total_token_num, rank_in_node) + for rank_in_node in range(0, node_world_size, dp_world_size) + ] + + def get_refed_tokens_num(self, dp_rank_in_node): + return self.dp_rank_clients[dp_rank_in_node].get_refed_tokens_num() + + def get_tree_total_tokens_num(self, dp_rank_in_node): + return self.dp_rank_clients[dp_rank_in_node].get_tree_total_tokens_num() + + def get_unrefed_tokens_num(self, dp_rank_in_node): + return self.dp_rank_clients[dp_rank_in_node].get_unrefed_tokens_num() diff --git a/lightllm/server/router/dynamic_prompt/radix_cache.py b/lightllm/server/router/dynamic_prompt/radix_cache.py index 65ec4354b..a60d0a942 100644 --- a/lightllm/server/router/dynamic_prompt/radix_cache.py +++ b/lightllm/server/router/dynamic_prompt/radix_cache.py @@ -333,7 +333,9 @@ def _print_helper(self, node: TreeNode, indent): self._print_helper(child, indent=indent + 2) return - def free_radix_cache_to_get_enough_token(self, need_token_num): + def free_radix_cache_to_get_enough_token( + self, need_token_num=None, b_seq_len=None, b_ready_cache_len=None, is_prefill=False + ): assert self.mem_manager is not None if need_token_num > self.mem_manager.can_use_mem_size: need_evict_token_num = need_token_num - self.mem_manager.can_use_mem_size diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py 
b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 7c2311d56..99830a8db 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -10,6 +10,7 @@ from lightllm.utils.log_utils import init_logger from lightllm.models import get_model from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache +from lightllm.server.router.dynamic_prompt.paged_radix_cache import PagedRadixCache from lightllm.server.router.model_infer.infer_batch import InferReq, InferReqUpdatePack from lightllm.server.router.token_load import TokenLoad from lightllm.common.basemodel.infer_lock import g_infer_state_lock, InferStateLock @@ -139,8 +140,9 @@ def init_model(self, kvargs): self.model, self.is_multimodal = get_model(model_cfg, model_kvargs) self.model: TpPartBaseModel = self.model # for easy typing set_random_seed(2147483647) + radix_cache_class = PagedRadixCache if "page_size_variable" in self.mode else RadixCache self.radix_cache = ( - RadixCache( + radix_cache_class( get_unique_server_name(), self.model.mem_manager.size, self.rank_in_node, diff --git a/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py index 5059ad27f..448c0d987 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py @@ -77,7 +77,9 @@ def padded_prepare_prefill_inputs( # dynamic prompt cache 准备 token g_infer_state_lock.acquire() if g_infer_context.radix_cache is not None: - g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0] - padded_req_num) + g_infer_context.radix_cache.free_radix_cache_to_get_enough_token( + input_ids.shape[0] - padded_req_num, b_seq_len, b_ready_cache_len, True + ) mem_indexes = g_infer_context.req_manager.mem_manager.alloc( input_ids.shape[0] - padded_req_num, b_req_idx, b_seq_len, b_ready_cache_len, True ) @@ -164,7 +166,7 @@ def padded_prepare_decode_inputs( # dynamic prompt cache 准备 token g_infer_state_lock.acquire() if g_infer_context.radix_cache is not None: - g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(b_seq_len.shape[0] - padded_req_num) + g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(b_seq_len.shape[0] - padded_req_num, b_seq_len) mem_indexes = g_infer_context.req_manager.mem_manager.alloc( b_seq_len.shape[0] - padded_req_num, b_req_idx, b_seq_len ) diff --git a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py index 1c81ebfcd..e5e871d83 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py @@ -55,7 +55,9 @@ def prepare_prefill_inputs( # dynamic prompt cache 准备 token g_infer_state_lock.acquire() if g_infer_context.radix_cache is not None: - g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0]) + g_infer_context.radix_cache.free_radix_cache_to_get_enough_token( + input_ids.shape[0], b_seq_len, b_ready_cache_len, True + ) mem_indexes = g_infer_context.req_manager.mem_manager.alloc( input_ids.shape[0], b_req_idx, b_seq_len, b_ready_cache_len, True ) @@ -113,7 +115,7 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In # dynamic prompt cache 准备 
token g_infer_state_lock.acquire() if g_infer_context.radix_cache is not None: - g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(b_seq_len.shape[0]) + g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(b_seq_len.shape[0], b_seq_len) mem_indexes = g_infer_context.req_manager.mem_manager.alloc(b_seq_len.shape[0], b_req_idx, b_seq_len) g_infer_state_lock.release() diff --git a/lightllm/utils/envs_utils.py b/lightllm/utils/envs_utils.py index 6c9a070e2..bd2752e2f 100644 --- a/lightllm/utils/envs_utils.py +++ b/lightllm/utils/envs_utils.py @@ -158,6 +158,15 @@ def set_triton_autotune_level(level: int): return +@lru_cache(maxsize=None) +def get_page_size(): + try: + args = get_env_start_args() + return int(os.getenv("PAGE_SIZE", 4)) if "page_size_variable" in args.mode else 1 + except: + return 1 + + g_model_init_done = False From 2968784f517c06ce1ccae8f5821e899a4ebe71eb Mon Sep 17 00:00:00 2001 From: niushengxiao Date: Fri, 1 Aug 2025 18:21:07 +0800 Subject: [PATCH 3/9] feat: support page size variable for flashinfer --- lightllm/models/llama/flashinfer_struct.py | 89 +++++--- .../layer_infer/transformer_layer_infer.py | 52 ++++- lightllm/utils/envs_utils.py | 2 +- .../test_context_flashattention_nopad.py | 5 +- ..._context_flashattention_nopad_fa3_paged.py | 163 +++++++++++++ ...t_flashattention_nopad_flashinfer_paged.py | 214 ++++++++++++++++++ .../test_token_attention_nopad_fa3_paged.py | 186 +++++++++++++++ ..._token_attention_nopad_flashinfer_paged.py | 169 ++++++++++++++ 8 files changed, 843 insertions(+), 37 deletions(-) create mode 100644 unit_tests/models/llama/test_context_flashattention_nopad_fa3_paged.py create mode 100644 unit_tests/models/llama/test_context_flashattention_nopad_flashinfer_paged.py create mode 100644 unit_tests/models/llama/test_token_attention_nopad_fa3_paged.py create mode 100644 unit_tests/models/llama/test_token_attention_nopad_flashinfer_paged.py diff --git a/lightllm/models/llama/flashinfer_struct.py b/lightllm/models/llama/flashinfer_struct.py index a0c40b57a..3b9a378c4 100644 --- a/lightllm/models/llama/flashinfer_struct.py +++ b/lightllm/models/llama/flashinfer_struct.py @@ -3,16 +3,21 @@ import numpy as np import torch.distributed as dist from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.utils.envs_utils import get_env_start_args +from lightllm.utils.envs_utils import get_env_start_args, get_page_size from lightllm.models.deepseek2.triton_kernel.repack_kv_index import repack_kv_index +def cdiv(a, b): + return (a + b - 1) // b + + class LlamaFlashInferStateInfo(LlamaInferStateInfo): def __init__(self): super().__init__() self.prefill_wrapper = None self.decode_wrapper = None self.flashinfer_extra_state = None + self.page_size = get_page_size() def init_some_extra_state(self, model, input_ids: torch.Tensor): super().init_some_extra_state(model, input_ids) @@ -22,29 +27,41 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): if not self.is_prefill: if get_env_start_args().enable_flashinfer_decode: - self.kv_last_page_len_buffer = torch.full( - (self.batch_size,), 1, dtype=torch.int32, device=input_ids.device - ) + self.kv_last_page_len = torch.full((self.batch_size,), 1, dtype=torch.int32, device=input_ids.device) + length = cdiv(self.flashinfer_extra_state.max_seq_length, self.page_size) if self.batch_size <= model.graph_max_batch_size: self.kv_indices = self.flashinfer_extra_state.kv_indices_buffer[self.microbatch_index][ - : self.batch_size * 
self.flashinfer_extra_state.max_seq_length + : self.batch_size * length ] else: self.kv_indices = torch.empty( - self.batch_size * self.flashinfer_extra_state.max_seq_length, + self.batch_size * length, dtype=torch.int32, device=input_ids.device, ) - repack_kv_index( - self.req_manager.req_to_token_indexs, - self.b_req_idx, - self.b_seq_len, - self.b_start_loc, - self.max_len_in_batch, - self.kv_indices, - ) self.kv_starts = self.b1_cu_kv_seq_len.int() + if "page_size_variable" in model.mode: + b_page_len = cdiv(self.b_seq_len, self.page_size) + self.kv_starts[1:] = b_page_len.cumsum(0) + self.kv_last_page_len = self.b_seq_len - (b_page_len - 1) * self.page_size + repack_kv_index( + self.req_manager.req_to_page_indexs, + self.b_req_idx, + b_page_len, + self.kv_starts[:-1], + cdiv(self.max_kv_seq_len, self.page_size), + self.kv_indices, + ) + else: + repack_kv_index( + self.req_manager.req_to_token_indexs, + self.b_req_idx, + self.b_seq_len, + self.b_start_loc, + self.max_kv_seq_len, + self.kv_indices, + ) if self.decode_wrapper is None: self.decode_wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper( self.flashinfer_extra_state.workspace_buffer, @@ -53,16 +70,16 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): use_tensor_cores=True, paged_kv_indptr_buffer=self.kv_starts, paged_kv_indices_buffer=self.kv_indices, - paged_kv_last_page_len_buffer=self.kv_last_page_len_buffer, + paged_kv_last_page_len_buffer=self.kv_last_page_len, ) self.decode_wrapper.plan( self.kv_starts, self.kv_indices, - self.kv_last_page_len_buffer, + self.kv_last_page_len, self.flashinfer_extra_state.tp_q_head_num, self.flashinfer_extra_state.tp_kv_head_num, self.flashinfer_extra_state.head_dim, - 1, + self.page_size, q_data_type=self.flashinfer_extra_state.q_data_type, kv_data_type=self.flashinfer_extra_state.kv_data_type, non_blocking=True, @@ -72,19 +89,33 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): q_starts = self.b1_cu_q_seq_len.int() kv_starts = self.b1_cu_kv_seq_len.int() kv_last_page_len = torch.full((self.batch_size,), 1, dtype=torch.int32, device=input_ids.device) + length = cdiv(self.flashinfer_extra_state.max_seq_length, self.page_size) kv_indices = torch.empty( - self.batch_size * self.flashinfer_extra_state.max_seq_length, + self.batch_size * length, dtype=torch.int32, device=input_ids.device, ) - repack_kv_index( - self.req_manager.req_to_token_indexs, - self.b_req_idx, - self.b_seq_len, - kv_starts[:-1], - self.max_kv_seq_len, - kv_indices, - ) + if "page_size_variable" in model.mode: + b_page_len = cdiv(self.b_seq_len, self.page_size) + kv_starts[1:] = b_page_len.cumsum(0) + kv_last_page_len = self.b_seq_len - (b_page_len - 1) * self.page_size + repack_kv_index( + self.req_manager.req_to_page_indexs, + self.b_req_idx, + b_page_len, + kv_starts[:-1], + cdiv(self.max_kv_seq_len, self.page_size), + kv_indices, + ) + else: + repack_kv_index( + self.req_manager.req_to_token_indexs, + self.b_req_idx, + self.b_seq_len, + kv_starts[:-1], + self.max_kv_seq_len, + kv_indices, + ) self.prefill_wrapper = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper( self.flashinfer_extra_state.workspace_buffer, qo_indptr_buf=q_starts, @@ -100,7 +131,7 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): self.flashinfer_extra_state.tp_q_head_num, self.flashinfer_extra_state.tp_kv_head_num, self.flashinfer_extra_state.head_dim, - 1, + self.page_size, causal=True, pos_encoding_mode="NONE", logits_soft_cap=0.0, @@ -115,11 +146,11 @@ def 
copy_for_cuda_graph(self, new_infer_state): self.decode_wrapper.plan( new_infer_state.kv_starts, new_infer_state.kv_indices, - new_infer_state.kv_last_page_len_buffer, + new_infer_state.kv_last_page_len, new_infer_state.flashinfer_extra_state.tp_q_head_num, new_infer_state.flashinfer_extra_state.tp_kv_head_num, new_infer_state.flashinfer_extra_state.head_dim, - 1, + self.page_size, q_data_type=new_infer_state.flashinfer_extra_state.q_data_type, kv_data_type=new_infer_state.flashinfer_extra_state.kv_data_type, non_blocking=True, diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py index 6683df346..31c5f02ba 100755 --- a/lightllm/models/llama/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/llama/layer_infer/transformer_layer_infer.py @@ -107,9 +107,16 @@ def _bind_attention(self): raise Exception(f"Unsupported mode for fa3 backend: {self.mode}") return elif get_env_start_args().enable_flashinfer_prefill: - self._context_attention_kernel = partial( - LlamaTransformerLayerInfer._context_attention_flashinfer_kernel, self - ) + if "page_size_variable" in self.mode: + self._context_attention_kernel = partial( + LlamaTransformerLayerInfer._paged_context_attention_flashinfer_kernel, self + ) + elif not self.mode: + self._context_attention_kernel = partial( + LlamaTransformerLayerInfer._context_attention_flashinfer_kernel, self + ) + else: + raise Exception(f"Unsupported mode for flashinfer backend: {self.mode}") else: self._context_attention_kernel = partial(LlamaTransformerLayerInfer._context_attention_kernel, self) if "ppl_int8kv" in self.mode: @@ -174,6 +181,12 @@ def _bind_attention(self): self._copy_kv_to_mem_cache = partial( LlamaTransformerLayerInfer._copy_kv_to_mem_cache_with_calibration, self ) + elif "page_size_variable" in self.mode: + assert get_env_start_args().enable_flashinfer_prefill and get_env_start_args().enable_flashinfer_decode + self._token_attention_kernel = partial( + LlamaTransformerLayerInfer._paged_token_decode_attention_flashinfer, self + ) + self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self) elif not self.mode: if get_env_start_args().enable_flashinfer_decode: self._token_attention_kernel = partial( @@ -274,6 +287,21 @@ def _context_attention_flashinfer_kernel( ) return o_tensor + def _paged_context_attention_flashinfer_kernel( + self, q, kv, infer_state: LlamaFlashInferStateInfo, layer_weight, out=None + ) -> torch.Tensor: + o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out + page_size = get_page_size() + kv = infer_state.mem_manager.kv_buffer[self.layer_num_].view( + -1, page_size, 2 * self.tp_k_head_num_, self.head_dim_ + ) + infer_state.prefill_wrapper.run( + q.view(q.shape[0], -1, self.head_dim_), + (kv[:, :, : self.tp_k_head_num_, :], kv[:, :, self.tp_k_head_num_ :, :]), + out=o_tensor.view(q.shape[0], -1, self.head_dim_), + ) + return o_tensor + def _context_attention_kernel( self, q, kv, infer_state: LlamaInferStateInfo, layer_weight, out=None ) -> torch.Tensor: @@ -587,6 +615,24 @@ def _token_decode_attention_flashinfer(self, q, infer_state: LlamaFlashInferStat ) return o_tensor + def _paged_token_decode_attention_flashinfer( + self, q, infer_state: LlamaFlashInferStateInfo, layer_weight, out=None + ): + batch_size = infer_state.batch_size + calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_) + + o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out + page_size = 
get_page_size() + kv = infer_state.mem_manager.kv_buffer[self.layer_num_].view( + -1, page_size, 2 * self.tp_k_head_num_, self.head_dim_ + ) + infer_state.decode_wrapper.run( + q.view(calcu_shape1), + (kv[:, :, : self.tp_k_head_num_, :], kv[:, :, self.tp_k_head_num_ :, :]), + out=o_tensor.view(calcu_shape1), + ) + return o_tensor + def _token_decode_attention_normal(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None): total_token_num = infer_state.total_token_num batch_size = infer_state.batch_size diff --git a/lightllm/utils/envs_utils.py b/lightllm/utils/envs_utils.py index bd2752e2f..c3a9469f0 100644 --- a/lightllm/utils/envs_utils.py +++ b/lightllm/utils/envs_utils.py @@ -162,7 +162,7 @@ def set_triton_autotune_level(level: int): def get_page_size(): try: args = get_env_start_args() - return int(os.getenv("PAGE_SIZE", 4)) if "page_size_variable" in args.mode else 1 + return int(os.getenv("PAGE_SIZE", 64)) if "page_size_variable" in args.mode else 1 except: return 1 diff --git a/unit_tests/models/llama/test_context_flashattention_nopad.py b/unit_tests/models/llama/test_context_flashattention_nopad.py index f24ab619b..94e61cfda 100644 --- a/unit_tests/models/llama/test_context_flashattention_nopad.py +++ b/unit_tests/models/llama/test_context_flashattention_nopad.py @@ -10,7 +10,6 @@ context_attention_fwd_no_prompt_cache, ) from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.common.req_manager import ReqManager logger = init_logger(__name__) @@ -56,8 +55,6 @@ def test_context_attention_fwd(batch, seqlen, q_heads, kv_heads, head_dim): infer_state.batch_size = Z infer_state.max_len_in_batch = N_CTX infer_state.total_token_num = Z * N_CTX - infer_state.req_manager = ReqManager(Z, N_CTX, None) - infer_state.req_manager.req_to_token_indexs = req_to_token_indexs infer_state.b_req_idx = b_req_idx infer_state.b_seq_len = b_seq_len infer_state.b_ready_cache_len = b_ready_cache_len @@ -73,7 +70,7 @@ def test_context_attention_fwd(batch, seqlen, q_heads, kv_heads, head_dim): infer_state.b_seq_len, infer_state.b_ready_cache_len, infer_state.max_len_in_batch, - infer_state.req_manager.req_to_token_indexs, + req_to_token_indexs, ) batch_size = Z diff --git a/unit_tests/models/llama/test_context_flashattention_nopad_fa3_paged.py b/unit_tests/models/llama/test_context_flashattention_nopad_fa3_paged.py new file mode 100644 index 000000000..e7702f084 --- /dev/null +++ b/unit_tests/models/llama/test_context_flashattention_nopad_fa3_paged.py @@ -0,0 +1,163 @@ +import torch +import time +import pytest +import triton as tl +import numpy as np +import torch.nn.functional as F +from lightllm.utils.log_utils import init_logger +from lightllm.models.llama.triton_kernel.context_flashattention_nopad import ( + context_attention_fwd, +) +from lightllm.models.llama.infer_struct import LlamaInferStateInfo +from lightllm.utils.sgl_utils import flash_attn_with_kvcache +from lightllm.common.basemodel.triton_kernel.q_per_head_fp8_quant import q_per_head_fp8_quant + +logger = init_logger(__name__) + +seed = 42 +torch.manual_seed(seed) + +if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def kv_quantize_per_head_fp8(kv_buffer: torch.Tensor, seq_lens): + device = kv_buffer.device + B = seq_lens.size(0) + min_fp8 = torch.finfo(torch.float8_e4m3fn).min + max_fp8 = torch.finfo(torch.float8_e4m3fn).max + _, S_max, H, D = kv_buffer.shape + seq_range = torch.arange(S_max, device=device)[None, :] + valid_mask = (seq_range < seq_lens[:, 
None]).view(B, S_max, 1, 1) + masked = kv_buffer * valid_mask + max_per_bh = masked.abs().amax(dim=(1, 3)) # [B, H] + scales = torch.where(max_per_bh > 0, max_per_bh / max_fp8, torch.ones_like(max_per_bh)).to(torch.float32) + scales_exp = scales.view(B, 1, H, 1) + q = (kv_buffer / scales_exp).clamp(min_fp8, max_fp8).to(torch.float8_e4m3fn) + return q, scales + + +@pytest.mark.parametrize( + "batch, seqlen, q_heads, kv_heads, head_dim", + [ + (a, b, c, d, e) + for a in [1, 16, 32, 128, 512] + for b in [16, 32, 512, 1024] + for c in [28] + for d in [4] + for e in [128] + ], +) +def test_context_attention_fwd_fa3_fp8(batch, seqlen, q_heads, kv_heads, head_dim): + Z, N_CTX, Q_HEADS, KV_HEADS, HEAD_DIM = batch, seqlen, q_heads, kv_heads, head_dim + dtype = torch.bfloat16 + page_size = 4 + kv = torch.randn((Z * N_CTX // page_size, page_size, 2 * KV_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + + max_input_len = Z * N_CTX + req_to_page_indexs = ( + torch.randperm(max_input_len // page_size, dtype=torch.int32).cuda().view(Z, N_CTX // page_size) + ) + req_to_token_indexs = ( + req_to_page_indexs.unsqueeze(-1) * page_size + torch.arange(page_size, dtype=torch.int32, device="cuda") + ).reshape(Z, N_CTX) + b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") * (N_CTX // 2) + rand_num = torch.randint_like(b_seq_len, high=(N_CTX // 2), dtype=torch.int32, device="cuda") + b_seq_len += rand_num + b_ready_cache_len = torch.zeros_like(b_seq_len, dtype=torch.int32, device="cuda") + if N_CTX > 1: + b_ready_cache_len = torch.randint_like(b_seq_len, high=(N_CTX - 1) // 2, dtype=torch.int32, device="cuda") + b_req_idx = torch.randperm(Z, dtype=torch.int32).cuda() + q_lens = b_seq_len - b_ready_cache_len + q_start_loc = q_lens.cumsum(0) - q_lens + + q = torch.randn((q_lens.sum(), Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + o = torch.zeros((q_lens.sum(), Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + o1 = torch.zeros((q_lens.sum(), Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + + infer_state = LlamaInferStateInfo() + infer_state.batch_size = Z + infer_state.max_len_in_batch = N_CTX + infer_state.total_token_num = Z * N_CTX + infer_state.b_req_idx = b_req_idx + infer_state.b_seq_len = b_seq_len + infer_state.b_ready_cache_len = b_ready_cache_len + infer_state.b_start_loc = q_start_loc + + context_attention_fwd( + q, + kv.view(-1, 2 * KV_HEADS, HEAD_DIM)[:, :KV_HEADS, :], + kv.view(-1, 2 * KV_HEADS, HEAD_DIM)[:, KV_HEADS:, :], + o, + infer_state.b_req_idx, + infer_state.b_start_loc, + infer_state.b_seq_len, + infer_state.b_ready_cache_len, + infer_state.max_len_in_batch, + req_to_token_indexs, + ) + + batch_size = Z + head_dim = HEAD_DIM + q_heads = Q_HEADS + kv_heads = KV_HEADS + page_table = torch.empty((batch_size, N_CTX), dtype=torch.int32, device="cuda") + page_table.copy_(req_to_token_indexs[b_req_idx, :N_CTX]) + + q_starts = torch.zeros((Z + 1,)).int().cuda() + q_starts[1:] = torch.cumsum(b_seq_len - b_ready_cache_len, dim=0) + kv_starts = torch.zeros_like(q_starts) + kv_starts[1:] = torch.cumsum(b_seq_len, dim=0) + + k_cache = kv.view(-1, 2 * KV_HEADS, HEAD_DIM)[:, :KV_HEADS, :] + v_cache = kv.view(-1, 2 * KV_HEADS, HEAD_DIM)[:, KV_HEADS:, :] + o1 = flash_attn_with_kvcache( + q=q, + k_cache=k_cache.reshape(-1, 1, kv_heads, head_dim), + v_cache=v_cache.reshape(-1, 1, kv_heads, head_dim), + page_table=page_table, + cache_seqlens=infer_state.b_seq_len, + cu_seqlens_q=q_starts, + cu_seqlens_k_new=kv_starts, + max_seqlen_q=N_CTX, + causal=True, + window_size=(-1, -1), + softcap=0.0, + 
return_softmax_lse=False, + ) + + assert torch.allclose(o, o1, atol=1e-2, rtol=1e-2) + cos_sim1 = F.cosine_similarity(o, o1).mean() + print(f"cos_sim1: {cos_sim1}") + assert cos_sim1.item() == 1 + + k_cache_paged = k_cache.reshape(-1, page_size, kv_heads, head_dim) + v_cache_paged = v_cache.reshape(-1, page_size, kv_heads, head_dim) + + page_table_paged = torch.empty((batch_size, N_CTX // page_size), dtype=torch.int32, device="cuda") + page_table_paged.copy_(req_to_page_indexs[b_req_idx, : N_CTX // page_size]) + + o2 = flash_attn_with_kvcache( + q=q, + k_cache=k_cache_paged, + v_cache=v_cache_paged, + page_table=page_table_paged, + cache_seqlens=infer_state.b_seq_len, + cu_seqlens_q=q_starts, + cu_seqlens_k_new=kv_starts, + max_seqlen_q=N_CTX, + causal=True, + window_size=(-1, -1), + softcap=0.0, + return_softmax_lse=False, + ) + + assert torch.allclose(o1, o2, atol=1e-2, rtol=1e-2) + cos_sim2 = F.cosine_similarity(o1, o2).mean() + print(f"cos_sim2: {cos_sim2}") + assert cos_sim2.item() == 1 + + +if __name__ == "__main__": + test_context_attention_fwd_fa3_fp8(32, 16384, 32, 4, 128) diff --git a/unit_tests/models/llama/test_context_flashattention_nopad_flashinfer_paged.py b/unit_tests/models/llama/test_context_flashattention_nopad_flashinfer_paged.py new file mode 100644 index 000000000..763a80015 --- /dev/null +++ b/unit_tests/models/llama/test_context_flashattention_nopad_flashinfer_paged.py @@ -0,0 +1,214 @@ +import torch +import time +import pytest +import numpy as np +import torch.nn.functional as F +import flashinfer +from lightllm.utils.log_utils import init_logger +from lightllm.models.llama.triton_kernel.context_flashattention_nopad import ( + context_attention_fwd, + context_attention_fwd_no_prompt_cache, +) +from lightllm.models.llama.infer_struct import LlamaInferStateInfo + +logger = init_logger(__name__) + +seed = 42 +torch.manual_seed(seed) + +if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +@pytest.mark.parametrize( + "batch, seqlen, q_heads, kv_heads, head_dim", + [ + (a, b, c, d, e) + for a in [1, 16, 32, 128, 512] + for b in [16, 32, 512, 1024] + for c in [28] + for d in [4] + for e in [128] + ], +) +def test_context_attention_fwd(batch, seqlen, q_heads, kv_heads, head_dim): + Z, N_CTX, Q_HEADS, KV_HEADS, HEAD_DIM = batch, seqlen, q_heads, kv_heads, head_dim + dtype = torch.bfloat16 + page_size = 4 + kv = torch.randn((Z * N_CTX // page_size, page_size, 2 * KV_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + + max_input_len = Z * N_CTX + req_to_page_indexs = ( + torch.randperm(max_input_len // page_size, dtype=torch.int32).cuda().view(Z, N_CTX // page_size) + ) + req_to_token_indexs = ( + req_to_page_indexs.unsqueeze(-1) * page_size + torch.arange(page_size, dtype=torch.int32, device="cuda") + ).reshape(Z, N_CTX) + + b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") * N_CTX + b_ready_cache_len = torch.zeros_like(b_seq_len, dtype=torch.int32, device="cuda") + b_ready_cache_len = torch.randint_like(b_seq_len, high=N_CTX - 1, dtype=torch.int32, device="cuda") + b_req_idx = torch.randperm(Z, dtype=torch.int32).cuda() + q_lens = b_seq_len - b_ready_cache_len + q_start_loc = q_lens.cumsum(0) - q_lens + + q = torch.randn((q_lens.sum(), Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + o = torch.zeros((q_lens.sum(), Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + o1 = torch.zeros((q_lens.sum(), Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + + infer_state = LlamaInferStateInfo() + infer_state.batch_size = Z + 
infer_state.max_len_in_batch = N_CTX + infer_state.total_token_num = Z * N_CTX + infer_state.b_req_idx = b_req_idx + infer_state.b_seq_len = b_seq_len + infer_state.b_ready_cache_len = b_ready_cache_len + infer_state.b_start_loc = q_start_loc + + context_attention_fwd( + q, + kv.view(-1, 2 * KV_HEADS, HEAD_DIM)[:, :KV_HEADS, :], + kv.view(-1, 2 * KV_HEADS, HEAD_DIM)[:, KV_HEADS:, :], + o, + infer_state.b_req_idx, + infer_state.b_start_loc, + infer_state.b_seq_len, + infer_state.b_ready_cache_len, + infer_state.max_len_in_batch, + req_to_token_indexs, + ) + + batch_size = Z + head_dim = HEAD_DIM + q_heads = Q_HEADS + kv_heads = KV_HEADS + workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8).to(0) + q_starts = torch.zeros((Z + 1,)).int().cuda() + q_starts[1:] = torch.cumsum(b_seq_len - b_ready_cache_len, dim=0) + + num_pages_per_seq = torch.ceil(b_seq_len.float() / page_size).int() + kv_starts = torch.zeros((Z + 1,)).int().cuda() + kv_starts[1:] = torch.cumsum(num_pages_per_seq, dim=0) + + q_indptr = q_starts.int() + kv_indptr = kv_starts.int() + + total_pages = num_pages_per_seq.sum().item() + kv_indices = torch.zeros(total_pages, dtype=torch.int32, device="cuda") + + # 设置kv_indices + b_start_loc = num_pages_per_seq.cumsum(0) - num_pages_per_seq + for req, sl, start in zip(b_req_idx, num_pages_per_seq, b_start_loc): + kv_indices[start : start + sl] = req_to_page_indexs[req][:sl] + + kv_last_page_len_buffer = torch.empty(batch_size, device="cuda:0", dtype=torch.int32) + wrapper = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper( + workspace_buffer, + qo_indptr_buf=q_indptr, + paged_kv_indptr_buf=kv_indptr, + paged_kv_indices_buf=kv_indices, + paged_kv_last_page_len_buf=kv_last_page_len_buffer, + ) + + # 设置kv_last_page_len + kv_last_page_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda") + for i in range(Z): + seq_len = b_seq_len[i].item() + remainder = seq_len % page_size + kv_last_page_len[i] = remainder if remainder > 0 else page_size + + wrapper.plan( + q_indptr, + kv_indptr, + kv_indices, + kv_last_page_len, + q_heads, + kv_heads, + head_dim, + page_size, + causal=True, + pos_encoding_mode="NONE", + logits_soft_cap=0.0, + q_data_type=q.dtype, + kv_data_type=kv.dtype, + ) + k_cache = kv[:, :, :KV_HEADS, :] + v_cache = kv[:, :, KV_HEADS:, :] + wrapper.run(q, (k_cache, v_cache), out=o1, return_lse=False) + cos_sim1 = F.cosine_similarity(o, o1).mean() + assert cos_sim1 == 1.0 + + +@pytest.mark.parametrize( + "batch, seqlen, q_heads, kv_heads, head_dim", + [ + (a, b, c, d, e) + for a in [1, 16, 32, 128, 512] + for b in [16, 32, 512, 1024] + for c in [28] + for d in [4] + for e in [128] + ], +) +def test_context_attention_fwd_no_prompt_cache(batch, seqlen, q_heads, kv_heads, head_dim): + Z, N_CTX, Q_HEADS, KV_HEADS, HEAD_DIM = batch, seqlen, q_heads, kv_heads, head_dim + dtype = torch.bfloat16 + q = torch.randn((Z * N_CTX, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + k = torch.randn((Z * N_CTX, KV_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + v = torch.randn((Z * N_CTX, KV_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + + b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") * N_CTX + b_start_loc = b_seq_len.cumsum(0) - b_seq_len + + o = torch.zeros((Z * N_CTX, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + o1 = torch.zeros((Z * N_CTX, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + + infer_state = LlamaInferStateInfo() + infer_state.batch_size = Z + infer_state.max_len_in_batch = N_CTX + infer_state.b_seq_len = b_seq_len + 
infer_state.b_start_loc = b_start_loc + + context_attention_fwd_no_prompt_cache( + q, + k, + v, + o, + infer_state.b_start_loc, + infer_state.b_seq_len, + infer_state.max_len_in_batch, + ) + + head_dim = HEAD_DIM + q_heads = Q_HEADS + kv_heads = KV_HEADS + workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8).to(0) + q_starts = torch.zeros((Z + 1,)).int().cuda() + q_starts[1:] = torch.cumsum(b_seq_len, dim=0) + kv_starts = torch.zeros_like(q_starts) + kv_starts[1:] = torch.cumsum(b_seq_len, dim=0) + q_indptr = q_starts.int() + kv_indptr = kv_starts.int() + wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper( + workspace_buffer, + ) + wrapper.plan( + qo_indptr=q_indptr, + kv_indptr=kv_indptr, + num_qo_heads=q_heads, + num_kv_heads=kv_heads, + head_dim_qk=head_dim, + head_dim_vo=head_dim, + q_data_type=dtype, + causal=True, + ) + wrapper.run(q, k, v, out=o1, return_lse=False) + + # assert torch.allclose(o, o1, atol=1e-2, rtol=0) + cos_sim1 = F.cosine_similarity(o, o1).mean() + assert cos_sim1 == 1.0 + + +if __name__ == "__main__": + test_context_attention_fwd(32, 16384, 32, 4, 128) # 16384 is divisible by 4 diff --git a/unit_tests/models/llama/test_token_attention_nopad_fa3_paged.py b/unit_tests/models/llama/test_token_attention_nopad_fa3_paged.py new file mode 100644 index 000000000..1de2fbc34 --- /dev/null +++ b/unit_tests/models/llama/test_token_attention_nopad_fa3_paged.py @@ -0,0 +1,186 @@ +import torch +import time +import pytest +import numpy as np +import torch.nn.functional as F +from lightllm.utils.log_utils import init_logger +from lightllm.models.llama.infer_struct import LlamaInferStateInfo +from lightllm.models.llama.triton_kernel.gqa_decode_flashattention_nopad import gqa_decode_attention_fwd +from lightllm.utils.sgl_utils import flash_attn_with_kvcache + +logger = init_logger(__name__) + +seed = 42 +torch.manual_seed(seed) + +if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def kv_quantize_per_head_fp8(kv_buffer: torch.Tensor, seq_lens): + device = kv_buffer.device + B = seq_lens.size(0) + min_fp8 = torch.finfo(torch.float8_e4m3fn).min + max_fp8 = torch.finfo(torch.float8_e4m3fn).max + _, S_max, H, D = kv_buffer.shape + seq_range = torch.arange(S_max, device=device)[None, :] + valid_mask = (seq_range < seq_lens[:, None]).view(B, S_max, 1, 1) + masked = kv_buffer * valid_mask + max_per_bh = masked.float().abs().amax(dim=(1, 3)) # [B, H] + scales = torch.where(max_per_bh > 0, max_per_bh / max_fp8, torch.ones_like(max_per_bh)) + scales_exp = scales.view(B, 1, H, 1) + q = (kv_buffer / scales_exp).clamp(min_fp8, max_fp8).to(torch.float8_e4m3fn) + return q, scales + + +def ref_token_attention_nopad(q, k, v, o, q_h, h_dim, infer_state, req_to_token_indexs): + from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd + + total_token_num = infer_state.total_token_num + batch_size = infer_state.batch_size + calcu_shape1 = (batch_size, q_h, h_dim) + + att_m_tensor = torch.empty((q_h, total_token_num), dtype=torch.float32).cuda() + + token_att_fwd( + q.view(calcu_shape1), + k, + att_m_tensor, + req_to_token_indexs, + infer_state.b_req_idx, + infer_state.b_start_loc, + infer_state.b_seq_len, + infer_state.max_len_in_batch, + ) + + from lightllm.models.llama.triton_kernel.token_attention_softmax_and_reducev import ( + token_softmax_reducev_fwd, + ) + + token_softmax_reducev_fwd( + att_m_tensor, + v, + o, + req_to_token_indexs, + infer_state.b_req_idx, + infer_state.b_start_loc, + 
infer_state.b_seq_len, + ) + return o + + +@pytest.mark.parametrize( + "batch, seqlen, q_heads, kv_heads, head_dim", + [ + (a, b, c, d, e) + for a in [1, 16, 32, 128, 512] + for b in [16, 32, 512, 1024] + for c in [28] + for d in [4] + for e in [128] + ], +) +def test_token_attention_nopad_fa3_fp8(batch, seqlen, q_heads, kv_heads, head_dim): + Z, N_CTX, Q_HEADS, KV_HEADS, HEAD_DIM = batch, seqlen, q_heads, kv_heads, head_dim + dtype = torch.bfloat16 + page_size = 4 + q = torch.randn((Z, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + kv = torch.randn((Z * N_CTX, 2 * KV_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + + max_input_len = Z * N_CTX + req_to_page_indexs = ( + torch.randperm(max_input_len // page_size, dtype=torch.int32).cuda().view(Z, N_CTX // page_size) + ) + req_to_token_indexs = ( + req_to_page_indexs.unsqueeze(-1) * page_size + torch.arange(page_size, dtype=torch.int32, device="cuda") + ).reshape(Z, N_CTX) + b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") * (N_CTX // 2) + rand_num = torch.randint_like(b_seq_len, high=(N_CTX // 2), dtype=torch.int32, device="cuda") + b_seq_len += rand_num + b_start_loc = b_seq_len.cumsum(0) - b_seq_len + b_req_idx = torch.randperm(Z, dtype=torch.int32).cuda() + + o = torch.zeros((Z, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + o1 = torch.zeros((Z, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + + infer_state = LlamaInferStateInfo() + infer_state.batch_size = Z + infer_state.max_len_in_batch = N_CTX + infer_state.total_token_num = Z * N_CTX + infer_state.b_req_idx = b_req_idx + infer_state.b_seq_len = b_seq_len + infer_state.b_start_loc = b_start_loc + + ref_token_attention_nopad( + q, + kv.view(-1, 2 * KV_HEADS, HEAD_DIM)[:, :KV_HEADS, :], + kv.view(-1, 2 * KV_HEADS, HEAD_DIM)[:, KV_HEADS:, :], + o, + Q_HEADS, + HEAD_DIM, + infer_state, + req_to_token_indexs, + ) + + batch_size = Z + head_dim = HEAD_DIM + q_heads = Q_HEADS + kv_heads = KV_HEADS + kv_starts = torch.zeros((Z + 1,)).int().cuda() + kv_starts[1:] = torch.cumsum(b_seq_len, dim=0) + q_starts = torch.arange(0, Z + 1).int().cuda() + page_table = torch.empty((batch_size, N_CTX), dtype=torch.int32).to(0) + page_table.copy_(req_to_token_indexs[b_req_idx, :N_CTX]) + + k_cache = kv.view(-1, 2 * KV_HEADS, HEAD_DIM)[:, :KV_HEADS, :].contiguous() + v_cache = kv.view(-1, 2 * KV_HEADS, HEAD_DIM)[:, KV_HEADS:, :].contiguous() + o1 = flash_attn_with_kvcache( + q=q, + k_cache=k_cache.view(-1, 1, kv_heads, head_dim), + v_cache=v_cache.view(-1, 1, kv_heads, head_dim), + page_table=page_table, + cache_seqlens=infer_state.b_seq_len, + cu_seqlens_q=q_starts, + cu_seqlens_k_new=kv_starts, + max_seqlen_q=1, + causal=False, + window_size=(-1, -1), + softcap=0.0, + return_softmax_lse=False, + ) + + # assert torch.allclose(o, o1, atol=1e-1, rtol=1e-1) + cos_sim1 = F.cosine_similarity(o, o1).mean() + print(cos_sim1) + assert cos_sim1 == 1 + + k_cache_paged = k_cache.reshape(-1, page_size, kv_heads, head_dim) + v_cache_paged = v_cache.reshape(-1, page_size, kv_heads, head_dim) + + page_table_paged = torch.empty((batch_size, N_CTX // page_size), dtype=torch.int32, device="cuda") + page_table_paged.copy_(req_to_page_indexs[b_req_idx, : N_CTX // page_size]) + + o2 = flash_attn_with_kvcache( + q=q, + k_cache=k_cache_paged, + v_cache=v_cache_paged, + page_table=page_table_paged, + cache_seqlens=infer_state.b_seq_len, + cu_seqlens_q=q_starts, + cu_seqlens_k_new=kv_starts, + max_seqlen_q=1, + causal=False, + window_size=(-1, -1), + softcap=0.0, + return_softmax_lse=False, + ) + + assert 
torch.allclose(o1, o2, atol=1e-2, rtol=1e-2) + cos_sim2 = F.cosine_similarity(o1, o2).mean() + print(cos_sim2) + assert cos_sim2.item() == 1 + + +if __name__ == "__main__": + test_token_attention_nopad_fa3_fp8(16, 16384, 28, 4, 128) diff --git a/unit_tests/models/llama/test_token_attention_nopad_flashinfer_paged.py b/unit_tests/models/llama/test_token_attention_nopad_flashinfer_paged.py new file mode 100644 index 000000000..9bb97be99 --- /dev/null +++ b/unit_tests/models/llama/test_token_attention_nopad_flashinfer_paged.py @@ -0,0 +1,169 @@ +import torch +import time +import pytest +import numpy as np +import torch.nn.functional as F +import flashinfer +from lightllm.utils.log_utils import init_logger +from lightllm.models.llama.infer_struct import LlamaInferStateInfo +from lightllm.models.llama.triton_kernel.gqa_decode_flashattention_nopad import gqa_decode_attention_fwd + +logger = init_logger(__name__) + +seed = 42 +torch.manual_seed(seed) + +if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def ref_token_attention_nopad(q, k, v, o, q_h, h_dim, infer_state, req_to_token_indexs): + from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd + + total_token_num = infer_state.total_token_num + batch_size = infer_state.batch_size + calcu_shape1 = (batch_size, q_h, h_dim) + + att_m_tensor = torch.empty((q_h, total_token_num), dtype=torch.float32).cuda() + + token_att_fwd( + q.view(calcu_shape1), + k, + att_m_tensor, + req_to_token_indexs, + infer_state.b_req_idx, + infer_state.b_start_loc, + infer_state.b_seq_len, + infer_state.max_len_in_batch, + ) + + from lightllm.models.llama.triton_kernel.token_attention_softmax_and_reducev import ( + token_softmax_reducev_fwd, + ) + + token_softmax_reducev_fwd( + att_m_tensor, + v, + o, + req_to_token_indexs, + infer_state.b_req_idx, + infer_state.b_start_loc, + infer_state.b_seq_len, + ) + return o + + +@pytest.mark.parametrize( + "batch, seqlen, q_heads, kv_heads, head_dim", + [ + (a, b, c, d, e) + for a in [1, 16, 32, 128, 512] + for b in [16, 32, 512, 1024] + for c in [28] + for d in [4] + for e in [128] + ], +) +def test_token_attention_nopad(batch, seqlen, q_heads, kv_heads, head_dim): + Z, N_CTX, Q_HEADS, KV_HEADS, HEAD_DIM = batch, seqlen, q_heads, kv_heads, head_dim + dtype = torch.bfloat16 + page_size = 4 + q = torch.randn((Z, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + kv = torch.randn((Z * N_CTX // page_size, page_size, 2 * KV_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + + max_input_len = Z * N_CTX + req_to_page_indexs = ( + torch.randperm(max_input_len // page_size, dtype=torch.int32).cuda().view(Z, N_CTX // page_size) + ) + req_to_token_indexs = ( + req_to_page_indexs.unsqueeze(-1) * page_size + torch.arange(page_size, dtype=torch.int32, device="cuda") + ).reshape(Z, N_CTX) + + b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") * N_CTX + b_start_loc = torch.arange(Z).cuda().int() * N_CTX + b_req_idx = torch.randperm(Z, dtype=torch.int32).cuda() + + o = torch.zeros((Z, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + o1 = torch.zeros((Z, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + + infer_state = LlamaInferStateInfo() + infer_state.batch_size = Z + infer_state.max_len_in_batch = N_CTX + infer_state.total_token_num = Z * N_CTX + infer_state.b_req_idx = b_req_idx + infer_state.b_seq_len = b_seq_len + infer_state.b_start_loc = b_start_loc + + ref_token_attention_nopad( + q, + kv.view(-1, 2 * KV_HEADS, HEAD_DIM)[:, :KV_HEADS, :], + 
kv.view(-1, 2 * KV_HEADS, HEAD_DIM)[:, KV_HEADS:, :], + o, + Q_HEADS, + HEAD_DIM, + infer_state, + req_to_token_indexs, + ) + # gqa_decode_attention_fwd( + # q, + # kv[:,:KV_HEADS,:], + # kv[:,KV_HEADS:,:], + # o, + # req_to_token_indexs, + # infer_state.b_req_idx, + # infer_state.b_seq_len, + # ) + + batch_size = Z + head_dim = HEAD_DIM + q_heads = Q_HEADS + kv_heads = KV_HEADS + workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8).to(0) + + num_pages_per_seq = torch.ceil(b_seq_len.float() / page_size).int() + kv_indptr = torch.zeros(Z + 1, dtype=torch.int32, device="cuda") + kv_indptr[1:] = torch.cumsum(num_pages_per_seq, dim=0) + + # Fill the paged KV data indices + total_pages = kv_indptr[-1].item() + kv_indices = torch.zeros(total_pages, dtype=torch.int32, device="cuda") + b_start_loc = num_pages_per_seq.cumsum(0) - num_pages_per_seq + for req, sl, start in zip(b_req_idx, num_pages_per_seq, b_start_loc): + kv_indices[start : start + sl] = req_to_page_indexs[req][:sl] + + # Calculate last page lengths + kv_last_page_len = torch.zeros(Z, dtype=torch.int32, device="cuda") + for i in range(Z): + seq_len = b_seq_len[i].item() + remainder = seq_len % page_size + kv_last_page_len[i] = remainder if remainder > 0 else page_size + + kv_last_page_len_buffer = torch.empty(batch_size, device="cuda:0", dtype=torch.int32) + wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper( + workspace_buffer, + "NHD", + use_cuda_graph=True, + use_tensor_cores=True, + paged_kv_indptr_buffer=kv_indptr, + paged_kv_indices_buffer=kv_indices, + paged_kv_last_page_len_buffer=kv_last_page_len_buffer, + ) + wrapper.plan( + kv_indptr, + kv_indices, + kv_last_page_len, + q_heads, + kv_heads, + head_dim, + page_size, + q_data_type=dtype, + non_blocking=True, + ) + wrapper.run(q, (kv[:, :, :KV_HEADS, :], kv[:, :, KV_HEADS:, :]), out=o1, return_lse=False) + cos_sim = F.cosine_similarity(o, o1).mean() + assert cos_sim == 1.0 + + +if __name__ == "__main__": + test_token_attention_nopad(32, 16384, 32, 4, 128) From f7678bf2ffd1ba77a95900dcce5b72632dbb8628 Mon Sep 17 00:00:00 2001 From: niushengxiao Date: Mon, 4 Aug 2025 16:35:15 +0800 Subject: [PATCH 4/9] feat: support page size variable for deepseek2 --- ...eepseek2_page_size_variable_mem_manager.py | 25 ++++++++ .../deepseek2/flashattention_infer_struct.py | 33 ++++++---- .../models/deepseek2/flashinfer_struct.py | 44 +++++++++---- .../layer_infer/transformer_layer_infer.py | 64 ++++++++++++++++++- lightllm/models/deepseek2/model.py | 5 ++ 5 files changed, 145 insertions(+), 26 deletions(-) create mode 100755 lightllm/common/deepseek2_page_size_variable_mem_manager.py diff --git a/lightllm/common/deepseek2_page_size_variable_mem_manager.py b/lightllm/common/deepseek2_page_size_variable_mem_manager.py new file mode 100755 index 000000000..6c3cd7014 --- /dev/null +++ b/lightllm/common/deepseek2_page_size_variable_mem_manager.py @@ -0,0 +1,25 @@ +import torch +import numpy as np +from .deepseek2_mem_manager import Deepseek2MemoryManager +from .page_size_variable_mem_manager import PageSizeVariableMemoryManager +from lightllm.utils.log_utils import init_logger +from lightllm.utils.envs_utils import get_page_size + + +def cdiv(a, b): + return (a + b - 1) // b + + +logger = init_logger(__name__) + + +class Deepseek2PageSizeVariableMemoryManager(PageSizeVariableMemoryManager, Deepseek2MemoryManager): + def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): + super().__init__(size, dtype, head_num, head_dim, 
layer_num, always_copy, mem_fraction) + + def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): + self.kv_buffer = torch.empty( + (layer_num, cdiv(size, get_page_size()) * get_page_size(), head_num, head_dim), + dtype=dtype, + device="cuda", + ) diff --git a/lightllm/models/deepseek2/flashattention_infer_struct.py b/lightllm/models/deepseek2/flashattention_infer_struct.py index d2ae055ce..bfdde792f 100644 --- a/lightllm/models/deepseek2/flashattention_infer_struct.py +++ b/lightllm/models/deepseek2/flashattention_infer_struct.py @@ -4,6 +4,11 @@ import torch.distributed as dist from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo from lightllm.utils.dist_utils import get_current_device_id +from lightllm.utils.envs_utils import get_page_size + + +def cdiv(a, b): + return (a + b - 1) // b class Deepseek2FlashAttentionStateInfo(Deepseek2InferStateInfo): @@ -38,20 +43,24 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): self.cu_seqlens_q = self.b1_cu_q_seq_len self.cu_seqlens_k = self.b1_cu_kv_seq_len max_seq_len_k = self.max_kv_seq_len + page_size = get_page_size() if self.batch_size <= model.graph_max_batch_size and self.max_len_in_batch <= model.graph_max_len_in_batch: - page_buffer = Deepseek2FlashAttentionStateInfo.get_page_table_buffer( - model.graph_max_batch_size, model.graph_max_len_in_batch + length = cdiv(model.graph_max_len_in_batch, page_size) + page_buffer = Deepseek2FlashAttentionStateInfo.get_page_table_buffer(model.graph_max_batch_size, length) + self.page_table = page_buffer[self.microbatch_index][: self.batch_size * length].reshape( + self.batch_size, length ) - self.page_table = page_buffer[self.microbatch_index][ - : self.batch_size * model.graph_max_len_in_batch - ].reshape(self.batch_size, model.graph_max_len_in_batch) else: - self.page_table = torch.empty((self.batch_size, self.max_len_in_batch), dtype=torch.int32).to( - input_ids.device - ) + length = cdiv(self.max_len_in_batch, page_size) + self.page_table = torch.empty((self.batch_size, length), dtype=torch.int32).to(input_ids.device) - self.page_table[:, :max_seq_len_k].copy_( - model.req_manager.req_to_token_indexs[self.b_req_idx, :max_seq_len_k] - ) - self.page_table[:, max_seq_len_k:].fill_(0) + if "page_size_variable" in model.mode: + length = cdiv(max_seq_len_k, page_size) + self.page_table[:, :length].copy_(model.req_manager.req_to_page_indexs[self.b_req_idx, :length]) + self.page_table[:, length:].fill_(0) + else: + self.page_table[:, :max_seq_len_k].copy_( + model.req_manager.req_to_token_indexs[self.b_req_idx, :max_seq_len_k] + ) + self.page_table[:, max_seq_len_k:].fill_(0) return diff --git a/lightllm/models/deepseek2/flashinfer_struct.py b/lightllm/models/deepseek2/flashinfer_struct.py index a00c45601..25ebe3889 100644 --- a/lightllm/models/deepseek2/flashinfer_struct.py +++ b/lightllm/models/deepseek2/flashinfer_struct.py @@ -3,16 +3,21 @@ import numpy as np import torch.distributed as dist from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo -from lightllm.utils.envs_utils import get_env_start_args +from lightllm.utils.envs_utils import get_env_start_args, get_page_size from lightllm.models.deepseek2.triton_kernel.repack_kv_index import repack_kv_index +def cdiv(a, b): + return (a + b - 1) // b + + class Deepseek2FlashInferStateInfo(Deepseek2InferStateInfo): def __init__(self): super().__init__() self.prefill_wrapper = None self.decode_wrapper = None self.flashinfer_extra_state = None + self.page_size = get_page_size() def 
init_some_extra_state(self, model, input_ids: torch.Tensor): super().init_some_extra_state(model, input_ids) @@ -23,24 +28,37 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): if not self.is_prefill: if get_env_start_args().enable_flashinfer_decode: self.q_indptr = torch.arange(self.batch_size + 1, dtype=torch.int32).to(input_ids.device) + length = cdiv(self.flashinfer_extra_state.max_seq_length, self.page_size) if self.batch_size <= model.graph_max_batch_size: self.kv_indices = self.flashinfer_extra_state.kv_indices_buffer[self.microbatch_index][ - : self.batch_size * self.flashinfer_extra_state.max_seq_length + : self.batch_size * length ] else: self.kv_indices = torch.empty( - self.batch_size * self.flashinfer_extra_state.max_seq_length, + self.batch_size * length, dtype=torch.int32, device=input_ids.device, ) - repack_kv_index( - self.req_manager.req_to_token_indexs, - self.b_req_idx, - self.b_seq_len, - self.b_start_loc, - self.max_len_in_batch, - self.kv_indices, - ) + if "page_size_variable" in model.mode: + b_page_len = cdiv(self.b_seq_len, self.page_size) + self.kv_starts[1:] = b_page_len.cumsum(0) + repack_kv_index( + self.req_manager.req_to_page_indexs, + self.b_req_idx, + b_page_len, + self.kv_starts[:-1], + cdiv(self.max_len_in_batch, self.page_size), + self.kv_indices, + ) + else: + repack_kv_index( + self.req_manager.req_to_token_indexs, + self.b_req_idx, + self.b_seq_len, + self.b_start_loc, + self.max_len_in_batch, + self.kv_indices, + ) if self.decode_wrapper is None: self.decode_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper( self.flashinfer_extra_state.workspace_buffer, @@ -58,7 +76,7 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): self.flashinfer_extra_state.tp_q_head_num, self.flashinfer_extra_state.kv_lora_rank, self.flashinfer_extra_state.qk_rope_head_dim, - 1, + self.page_size, False, # causal self.flashinfer_extra_state.softmax_scale, self.flashinfer_extra_state.q_data_type, @@ -97,7 +115,7 @@ def copy_for_cuda_graph(self, new_infer_state): new_infer_state.flashinfer_extra_state.tp_q_head_num, new_infer_state.flashinfer_extra_state.kv_lora_rank, new_infer_state.flashinfer_extra_state.qk_rope_head_dim, - 1, + self.page_size, False, # causal new_infer_state.flashinfer_extra_state.softmax_scale, new_infer_state.flashinfer_extra_state.q_data_type, diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py index ace54bba4..ea0d58212 100644 --- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py @@ -26,7 +26,7 @@ from functools import partial from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale from lightllm.distributed.communication_op import all_gather, all_gather_into_tensor, all_reduce, reduce_scatter_tensor -from lightllm.utils.envs_utils import get_env_start_args +from lightllm.utils.envs_utils import get_env_start_args, get_page_size from lightllm.utils.dist_utils import get_global_world_size from lightllm.utils.log_utils import init_logger from lightllm.utils.sgl_utils import flash_attn_varlen_func, flash_attn_with_kvcache, merge_state_v2 @@ -93,6 +93,18 @@ def _bind_attention(self): self._token_attention_kernel = partial( Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashdecoding_fp8, self ) + elif "page_size_variable" in self.mode: + self._copy_kv_to_mem_cache = 
partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_normal, self) + if get_env_start_args().enable_fa3: + self._token_attention_kernel = partial( + Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashattention_paged, self + ) + elif get_env_start_args().enable_flashinfer_decode: + self._token_attention_kernel = partial( + Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashinfer_paged, self + ) + else: + raise Exception("Page size variable mode is not supported in other backends.") else: self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_normal, self) if get_env_start_args().enable_fa3: @@ -574,6 +586,36 @@ def _token_gqa_decode_attention_flashattention( ) return o_tensor + def _token_gqa_decode_attention_flashattention_paged( + self, q, infer_state: Deepseek2FlashInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None + ): + page_size = get_page_size() + q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] + q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) + kv = infer_state.mem_manager.kv_buffer[self.layer_num_] + k_rope = kv[:, :, -self.qk_rope_head_dim :].reshape(-1, page_size, 1, self.qk_rope_head_dim) + kv_nope = kv[:, :, : -self.qk_rope_head_dim].reshape(-1, page_size, 1, self.kv_lora_rank) + k_descale, v_descale = None, None + o_tensor = flash_attn_with_kvcache( + q=q_rope, + k_cache=k_rope, + v_cache=kv_nope, + qv=q_nope, + page_table=infer_state.page_table, + cache_seqlens=infer_state.b_seq_len, + cu_seqlens_q=infer_state.cu_seqlens_q, + cu_seqlens_k_new=infer_state.cu_seqlens_k, + max_seqlen_q=1, + softmax_scale=self.softmax_scale, + causal=True, + window_size=(-1, -1), + softcap=0.0, + k_descale=k_descale, + v_descale=v_descale, + return_softmax_lse=False, + ) + return o_tensor + def _token_gqa_decode_attention_flashinfer( self, q, infer_state: Deepseek2FlashInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None ): @@ -593,6 +635,26 @@ def _token_gqa_decode_attention_flashinfer( ) return o_tensor + def _token_gqa_decode_attention_flashinfer_paged( + self, q, infer_state: Deepseek2FlashInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None + ): + page_size = get_page_size() + q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] + q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) + + kv = infer_state.mem_manager.kv_buffer[self.layer_num_] + o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) + + infer_state.decode_wrapper.run( + q_nope, + q_rope, + kv[:, :, : -self.qk_rope_head_dim].reshape(-1, page_size, 1, self.kv_lora_rank), + kv[:, :, -self.qk_rope_head_dim :].reshape(-1, page_size, 1, self.qk_rope_head_dim), + out=o_tensor, + return_lse=False, + ) + return o_tensor + def _token_gqa_decode_attention_flashdecoding( self, q, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None ): diff --git a/lightllm/models/deepseek2/model.py b/lightllm/models/deepseek2/model.py index 9101cb963..c2380c4ab 100644 --- a/lightllm/models/deepseek2/model.py +++ b/lightllm/models/deepseek2/model.py @@ -10,6 +10,7 @@ from lightllm.models.llama.model import LlamaTpPartModel from lightllm.common.deepseek2_mem_manager import Deepseek2MemoryManager +from lightllm.common.deepseek2_page_size_variable_mem_manager import Deepseek2PageSizeVariableMemoryManager from lightllm.common.deepseek2_fp8kv_mem_manager import 
Deepseek2FP8KVMemoryManager from lightllm.utils.log_utils import init_logger from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale @@ -97,6 +98,10 @@ def _init_mem_manager(self): manager_class = Deepseek2MemoryManager if "triton_fp8kv" in self.mode: manager_class = Deepseek2FP8KVMemoryManager + elif "page_size_variable" in self.mode: + manager_class = Deepseek2PageSizeVariableMemoryManager + elif self.mode: + raise ValueError(f"Unsupported mode for deepseek2: {self.mode}") # mtp 模式下需要在mem manger上扩展draft model使用的layer added_mtp_layer_num = 0 From 6d227b38d81fa41a24d2e0260182d692b465da12 Mon Sep 17 00:00:00 2001 From: niushengxiao Date: Thu, 28 Aug 2025 14:19:34 +0800 Subject: [PATCH 5/9] fix: fix the page not enough bug --- lightllm/common/mem_manager.py | 10 ++++ .../common/page_size_variable_mem_manager.py | 19 +++++-- .../deepseek2/flashattention_infer_struct.py | 8 +-- .../layer_infer/transformer_layer_infer.py | 12 ++--- .../llama/flashattention_infer_struct.py | 10 ++-- .../layer_infer/transformer_layer_infer.py | 18 +++---- lightllm/server/api_cli.py | 4 +- lightllm/server/api_start.py | 7 +++ .../dynamic_prompt/paged_radix_cache.py | 49 ++++--------------- .../router/req_queue/chunked_prefill/impl.py | 5 +- .../benchmark/static_inference/model_infer.py | 4 +- 11 files changed, 71 insertions(+), 75 deletions(-) diff --git a/lightllm/common/mem_manager.py b/lightllm/common/mem_manager.py index f3cf0419d..8ba047ee8 100755 --- a/lightllm/common/mem_manager.py +++ b/lightllm/common/mem_manager.py @@ -52,6 +52,7 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False layer_num, ) self.HOLD_TOKEN_MEMINDEX = self.size + # MemoryManager也需要个引用备份,供内部使用 self.req_to_token_indexs = None def get_cell_size(self): @@ -341,8 +342,17 @@ def __init__(self) -> None: SharedInt(f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}") for rank_in_node in range(0, self.node_world_size, self.dp_world_size) ] + self.shared_tp_info_pages = [ + SharedInt(f"{get_unique_server_name()}_mem_manger_can_use_page_num_{rank_in_node}") + for rank_in_node in range(0, self.node_world_size, self.dp_world_size) + ] def get_unrefed_token_num(self, dp_rank_in_node: int): if self.is_multinode_tp: return self.shared_tp_infos[0].get_value() return self.shared_tp_infos[dp_rank_in_node].get_value() + + def get_unrefed_page_num(self, dp_rank_in_node: int): + if self.is_multinode_tp: + return self.shared_tp_info_pages[0].get_value() + return self.shared_tp_info_pages[dp_rank_in_node].get_value() diff --git a/lightllm/common/page_size_variable_mem_manager.py b/lightllm/common/page_size_variable_mem_manager.py index 095648c01..8456f2902 100755 --- a/lightllm/common/page_size_variable_mem_manager.py +++ b/lightllm/common/page_size_variable_mem_manager.py @@ -3,7 +3,9 @@ from .mem_manager import MemoryManager from typing import List, Union from lightllm.utils.log_utils import init_logger -from lightllm.utils.envs_utils import get_page_size +from lightllm.utils.envs_utils import get_unique_server_name, get_page_size +from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt +from lightllm.utils.dist_utils import get_current_rank_in_node def cdiv(a, b): @@ -24,6 +26,12 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False self.mark_page_start = 0 self.can_use_page_size = cdiv(self.size, page_size) + rank_in_node = get_current_rank_in_node() + self.shared_can_use_page_num = SharedInt( + 
f"{get_unique_server_name()}_mem_manger_can_use_page_num_{rank_in_node}" + ) + self.shared_can_use_page_num.set_value(self.can_use_page_size) + def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): self.kv_buffer = torch.empty( (layer_num, cdiv(size, get_page_size()) * get_page_size(), 2 * head_num, head_dim), @@ -141,6 +149,7 @@ def alloc(self, need_size, b_req_idx, b_seq_len, b_ready_cache_len=None, is_pref token_idxs = self.get_paged_token_indexs(b_req_idx, page_size, b_seq_len, b_ready_cache_len, is_prefill) self.can_use_mem_size -= need_size self.shared_can_use_token_num.set_value(self.can_use_mem_size) + self.shared_can_use_page_num.set_value(self.can_use_page_size) return token_idxs def free(self, free_index: Union[torch.Tensor, List[int]]): @@ -154,12 +163,13 @@ def free(self, free_index: Union[torch.Tensor, List[int]]): if len(free_index) == 0: return - page_indices = free_index // page_size - unique_pages = torch.unique(page_indices) - for page_idx in sorted(unique_pages, reverse=True): # 逆序放回,保持池的相对顺序 + base_free_index = free_index[free_index % page_size == 0] + page_indices = base_free_index // page_size + for page_idx in sorted(page_indices, reverse=True): # 逆序放回,保持池的相对顺序 self.mark_page_start -= 1 self.page_idx_pool[self.mark_page_start] = page_idx self.can_use_page_size += 1 + self.shared_can_use_page_num.set_value(self.can_use_page_size) return @@ -168,6 +178,7 @@ def free_all(self): page_size = get_page_size() self.mark_page_start = 0 self.can_use_page_size = cdiv(self.size, page_size) + self.shared_can_use_page_num.set_value(self.can_use_page_size) self.page_idx_pool = torch.arange( 0, cdiv(self.size, page_size), dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True ) diff --git a/lightllm/models/deepseek2/flashattention_infer_struct.py b/lightllm/models/deepseek2/flashattention_infer_struct.py index bfdde792f..52ba3beb4 100644 --- a/lightllm/models/deepseek2/flashattention_infer_struct.py +++ b/lightllm/models/deepseek2/flashattention_infer_struct.py @@ -16,6 +16,7 @@ class Deepseek2FlashAttentionStateInfo(Deepseek2InferStateInfo): def __init__(self): super().__init__() + self.page_size = get_page_size() @classmethod def get_page_table_buffer(cls, graph_max_batch_size: int, max_seq_len: int): @@ -43,19 +44,18 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): self.cu_seqlens_q = self.b1_cu_q_seq_len self.cu_seqlens_k = self.b1_cu_kv_seq_len max_seq_len_k = self.max_kv_seq_len - page_size = get_page_size() if self.batch_size <= model.graph_max_batch_size and self.max_len_in_batch <= model.graph_max_len_in_batch: - length = cdiv(model.graph_max_len_in_batch, page_size) + length = cdiv(model.graph_max_len_in_batch, self.page_size) page_buffer = Deepseek2FlashAttentionStateInfo.get_page_table_buffer(model.graph_max_batch_size, length) self.page_table = page_buffer[self.microbatch_index][: self.batch_size * length].reshape( self.batch_size, length ) else: - length = cdiv(self.max_len_in_batch, page_size) + length = cdiv(self.max_len_in_batch, self.page_size) self.page_table = torch.empty((self.batch_size, length), dtype=torch.int32).to(input_ids.device) if "page_size_variable" in model.mode: - length = cdiv(max_seq_len_k, page_size) + length = cdiv(max_seq_len_k, self.page_size) self.page_table[:, :length].copy_(model.req_manager.req_to_page_indexs[self.b_req_idx, :length]) self.page_table[:, length:].fill_(0) else: diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py 
b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py index ea0d58212..ea2a80336 100644 --- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py @@ -26,7 +26,7 @@ from functools import partial from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale from lightllm.distributed.communication_op import all_gather, all_gather_into_tensor, all_reduce, reduce_scatter_tensor -from lightllm.utils.envs_utils import get_env_start_args, get_page_size +from lightllm.utils.envs_utils import get_env_start_args from lightllm.utils.dist_utils import get_global_world_size from lightllm.utils.log_utils import init_logger from lightllm.utils.sgl_utils import flash_attn_varlen_func, flash_attn_with_kvcache, merge_state_v2 @@ -589,12 +589,11 @@ def _token_gqa_decode_attention_flashattention( def _token_gqa_decode_attention_flashattention_paged( self, q, infer_state: Deepseek2FlashInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None ): - page_size = get_page_size() q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) kv = infer_state.mem_manager.kv_buffer[self.layer_num_] - k_rope = kv[:, :, -self.qk_rope_head_dim :].reshape(-1, page_size, 1, self.qk_rope_head_dim) - kv_nope = kv[:, :, : -self.qk_rope_head_dim].reshape(-1, page_size, 1, self.kv_lora_rank) + k_rope = kv[:, :, -self.qk_rope_head_dim :].reshape(-1, infer_state.page_size, 1, self.qk_rope_head_dim) + kv_nope = kv[:, :, : -self.qk_rope_head_dim].reshape(-1, infer_state.page_size, 1, self.kv_lora_rank) k_descale, v_descale = None, None o_tensor = flash_attn_with_kvcache( q=q_rope, @@ -638,7 +637,6 @@ def _token_gqa_decode_attention_flashinfer( def _token_gqa_decode_attention_flashinfer_paged( self, q, infer_state: Deepseek2FlashInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None ): - page_size = get_page_size() q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) @@ -648,8 +646,8 @@ def _token_gqa_decode_attention_flashinfer_paged( infer_state.decode_wrapper.run( q_nope, q_rope, - kv[:, :, : -self.qk_rope_head_dim].reshape(-1, page_size, 1, self.kv_lora_rank), - kv[:, :, -self.qk_rope_head_dim :].reshape(-1, page_size, 1, self.qk_rope_head_dim), + kv[:, :, : -self.qk_rope_head_dim].reshape(-1, infer_state.page_size, 1, self.kv_lora_rank), + kv[:, :, -self.qk_rope_head_dim :].reshape(-1, infer_state.page_size, 1, self.qk_rope_head_dim), out=o_tensor, return_lse=False, ) diff --git a/lightllm/models/llama/flashattention_infer_struct.py b/lightllm/models/llama/flashattention_infer_struct.py index 1f249c199..28611e901 100644 --- a/lightllm/models/llama/flashattention_infer_struct.py +++ b/lightllm/models/llama/flashattention_infer_struct.py @@ -18,6 +18,7 @@ class FlashAttentionStateInfo(LlamaInferStateInfo): def __init__(self): super().__init__() + self.page_size = get_page_size() @classmethod def get_page_table_buffer(cls, graph_max_batch_size: int, max_seq_len: int): @@ -32,7 +33,7 @@ def _init_flash_attention_state(self, model, input_ids: torch.Tensor): if self.is_prefill: self.cu_seqlens_q = self.b1_cu_q_seq_len.int() self.cu_seqlens_k = self.b1_cu_kv_seq_len.int() - length = cdiv(self.max_seq_len, get_page_size()) + length = cdiv(self.max_seq_len, self.page_size) self.page_table = 
torch.empty((self.batch_size, length), dtype=torch.int32, device=input_ids.device) if "page_size_variable" in model.mode: self.page_table.copy_(model.req_manager.req_to_page_indexs[self.b_req_idx, :length]) @@ -44,17 +45,16 @@ def _init_flash_attention_state(self, model, input_ids: torch.Tensor): self.cu_seqlens_k = self.b1_cu_kv_seq_len.int() max_seq_len_k = self.max_kv_seq_len if self.batch_size <= model.graph_max_batch_size and self.max_len_in_batch <= model.graph_max_len_in_batch: - page_size = get_page_size() - length = cdiv(model.graph_max_len_in_batch, page_size) + length = cdiv(model.graph_max_len_in_batch, self.page_size) page_buffer = FlashAttentionStateInfo.get_page_table_buffer(model.graph_max_batch_size, length) self.page_table = page_buffer[self.microbatch_index][: self.batch_size * length].reshape( self.batch_size, length ) else: - length = cdiv(self.max_len_in_batch, get_page_size()) + length = cdiv(self.max_len_in_batch, self.page_size) self.page_table = torch.empty((self.batch_size, length), dtype=torch.int32, device=input_ids.device) - length = cdiv(max_seq_len_k, get_page_size()) + length = cdiv(max_seq_len_k, self.page_size) if "page_size_variable" in model.mode: self.page_table[:, :length].copy_(model.req_manager.req_to_page_indexs[self.b_req_idx, :length]) else: diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py index 31c5f02ba..66a169996 100755 --- a/lightllm/models/llama/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/llama/layer_infer/transformer_layer_infer.py @@ -27,7 +27,7 @@ from lightllm.models.llama.triton_kernel.ppl_quant_copy_kv import destindex_copy_dequantize_kv from lightllm.distributed.communication_op import all_gather_into_tensor, reduce_scatter_tensor from lightllm.utils.log_utils import init_logger -from lightllm.utils.envs_utils import get_env_start_args, get_page_size +from lightllm.utils.envs_utils import get_env_start_args from lightllm.utils.light_utils import HAS_LIGHTLLM_KERNEL, light_ops from lightllm.common.basemodel.triton_kernel.q_per_head_fp8_quant import q_per_head_fp8_quant from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops @@ -291,9 +291,8 @@ def _paged_context_attention_flashinfer_kernel( self, q, kv, infer_state: LlamaFlashInferStateInfo, layer_weight, out=None ) -> torch.Tensor: o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out - page_size = get_page_size() kv = infer_state.mem_manager.kv_buffer[self.layer_num_].view( - -1, page_size, 2 * self.tp_k_head_num_, self.head_dim_ + -1, infer_state.page_size, 2 * self.tp_k_head_num_, self.head_dim_ ) infer_state.prefill_wrapper.run( q.view(q.shape[0], -1, self.head_dim_), @@ -356,13 +355,12 @@ def _context_attention_kernel_ppl_int8kv( def _paged_context_attention_flashattention( self, q, kv, infer_state: FlashAttentionStateInfo, layer_weight, out=None ): - page_size = get_page_size() cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :].reshape( - -1, page_size, self.tp_k_head_num_, self.head_dim_ + -1, infer_state.page_size, self.tp_k_head_num_, self.head_dim_ ) cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][ :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ].reshape(-1, page_size, self.tp_v_head_num_, self.head_dim_) + ].reshape(-1, infer_state.page_size, self.tp_v_head_num_, self.head_dim_) q = q.reshape(-1, self.tp_q_head_num_, self.head_dim_) k_descale, v_descale = None, None # disable 
quantization Lq = q.shape[-1] @@ -622,9 +620,8 @@ def _paged_token_decode_attention_flashinfer( calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_) o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out - page_size = get_page_size() kv = infer_state.mem_manager.kv_buffer[self.layer_num_].view( - -1, page_size, 2 * self.tp_k_head_num_, self.head_dim_ + -1, infer_state.page_size, 2 * self.tp_k_head_num_, self.head_dim_ ) infer_state.decode_wrapper.run( q.view(calcu_shape1), @@ -914,13 +911,12 @@ def _token_decode_attention_gqa_flashdecoding_vsm( def _paged_token_decode_attention_flashattention( self, q, infer_state: FlashAttentionStateInfo, layer_weight, out=None ): - page_size = get_page_size() cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :].reshape( - -1, page_size, self.tp_k_head_num_, self.head_dim_ + -1, infer_state.page_size, self.tp_k_head_num_, self.head_dim_ ) cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][ :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ].reshape(-1, page_size, self.tp_v_head_num_, self.head_dim_) + ].reshape(-1, infer_state.page_size, self.tp_v_head_num_, self.head_dim_) q = q.reshape(-1, self.tp_q_head_num_, self.head_dim_) k_descale, v_descale = None, None # disable quantization Lq = q.shape[-1] diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index ae9f7541d..b52d72d14 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -179,7 +179,7 @@ def make_argument_parser() -> argparse.ArgumentParser: nargs="+", help="""Model mode: [triton_int8kv | ppl_int8kv | ppl_fp16 | triton_flashdecoding | triton_gqa_attention | triton_gqa_flashdecoding | triton_fp8kv | offline_calibration_fp8kv - | export_fp8kv_calibration + | export_fp8kv_calibration | page_size_variable triton_flashdecoding mode is for long context, current support llama llama2 qwen; triton_gqa_attention and triton_gqa_flashdecoding is fast kernel for model which use GQA; triton_int8kv mode use int8 to store kv cache, can increase token capacity, use triton kernel; @@ -191,6 +191,8 @@ def make_argument_parser() -> argparse.ArgumentParser: Calibration need to disable cudagraph and use fa3 or flashinfer backend. ppl_int8kv mode use int8 to store kv cache, and use ppl fast kernel; ppl_fp16 mode use ppl fast fp16 decode attention kernel; + page_size_variable allow to use page size > 1, use PAGE_SIZE env to set page size, + page_size_variable only support fa3 and flashinfer backend for now you need to read source code to make sure the supported detail mode for all models""", ) parser.add_argument( diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 03c519d7b..6b89e8f1a 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -126,6 +126,13 @@ def normal_or_p_d_start(args): "--enable_flashinfer_prefill and --enable_flashinfer_decode" ) assert args.disable_cudagraph is True, "export_fp8kv_calibration mode need disable cudagraph" + if "page_size_variable" in args.mode: + assert args.enable_fa3 is True or ( + args.enable_flashinfer_prefill is True and args.enable_flashinfer_decode is True + ), ( + "page_size_variable mode need enable fa3 or flashinfer, add --enable_fa3 or " + "--enable_flashinfer_prefill and --enable_flashinfer_decode" + ) # 部分模式还不能支持与高级动态调度算法协同,to do. 
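For clarity, the new startup checks just reflect that page_size_variable only changes how KV-cache slots are grouped into pages, so it needs an attention backend that consumes a page table (fa3 or flashinfer). Below is a minimal, illustrative sketch of the page arithmetic this mode relies on; the PAGE_SIZE value and helper names are assumptions for illustration and are not part of this patch:

    # Illustrative only: map a flat KV token index to (page id, in-page offset),
    # mirroring the cdiv / "// page_size" arithmetic used throughout this series.
    PAGE_SIZE = 16  # assumed value of the PAGE_SIZE environment variable

    def token_to_page(token_index: int, page_size: int = PAGE_SIZE):
        # flat token index -> (page id, slot within that page)
        return token_index // page_size, token_index % page_size

    def pages_needed(seq_len: int, page_size: int = PAGE_SIZE) -> int:
        # same rounding-up division as cdiv(seq_len, page_size) in the memory managers
        return (seq_len + page_size - 1) // page_size

    assert token_to_page(33) == (2, 1)  # token 33 lives in page 2, slot 1
    assert pages_needed(33) == 3        # 33 tokens occupy 3 pages of 16 slots

The same rounding-up division is what the request queue and radix cache use when they estimate how many extra pages a batch may consume.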
if args.diverse_mode: diff --git a/lightllm/server/router/dynamic_prompt/paged_radix_cache.py b/lightllm/server/router/dynamic_prompt/paged_radix_cache.py index 33210b90d..687fa1a22 100644 --- a/lightllm/server/router/dynamic_prompt/paged_radix_cache.py +++ b/lightllm/server/router/dynamic_prompt/paged_radix_cache.py @@ -159,7 +159,7 @@ def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager=None) ) self.tree_total_tokens_num.arr[0] = 0 - def _get_page_aligned_key(self, key, value=None): + def _get_page_aligned_key(self, key, value=None, free_truncated=False): aligned_len = len(key) if aligned_len == 0: return None, None @@ -171,6 +171,13 @@ def _get_page_aligned_key(self, key, value=None): aligned_len = aligned_len & ~self._page_size_mask else: aligned_len = (aligned_len // self.page_size) * self.page_size + + # 释放被截断的部分 + if free_truncated and aligned_len < len(key) and self.mem_manager is not None: + truncated_value = value[aligned_len:] if value is not None else key[aligned_len:] + if len(truncated_value) > 0: + self.mem_manager.free(truncated_value) + return ( key[:aligned_len] if aligned_len > 0 else None, value[:aligned_len] if value is not None and aligned_len > 0 else None, @@ -182,7 +189,7 @@ def insert(self, key, value=None): value = key assert len(key) == len(value) # and len(key) >= 1 - key, value = self._get_page_aligned_key(key, value) + key, value = self._get_page_aligned_key(key, value, free_truncated=True) if key is None: return 0 return self._insert_helper(self.root_node, key, value) @@ -422,41 +429,3 @@ def release_mem(mem_index): mem_index = torch.concat(release_mems) self.mem_manager.free(mem_index) return - - -class _RadixCacheReadOnlyClient: - """ - router 端只读用的客户端,用于从共享内存中读取树结构中的信息,用于进行prompt cache 的调度估计。 - """ - - def __init__(self, unique_name, total_token_num, rank_in_node): - self.refed_tokens_num = SharedArray(f"{unique_name}_refed_tokens_num_{rank_in_node}", (1,), dtype=np.int64) - self.tree_total_tokens_num = SharedArray( - f"{unique_name}_tree_total_tokens_num_{rank_in_node}", (1,), dtype=np.int64 - ) - - def get_refed_tokens_num(self): - return self.refed_tokens_num.arr[0] - - def get_tree_total_tokens_num(self): - return self.tree_total_tokens_num.arr[0] - - def get_unrefed_tokens_num(self): - return self.tree_total_tokens_num.arr[0] - self.refed_tokens_num.arr[0] - - -class RadixCacheReadOnlyClient: - def __init__(self, unique_name, total_token_num, node_world_size, dp_world_size): - self.dp_rank_clients: List[_RadixCacheReadOnlyClient] = [ - _RadixCacheReadOnlyClient(unique_name, total_token_num, rank_in_node) - for rank_in_node in range(0, node_world_size, dp_world_size) - ] - - def get_refed_tokens_num(self, dp_rank_in_node): - return self.dp_rank_clients[dp_rank_in_node].get_refed_tokens_num() - - def get_tree_total_tokens_num(self, dp_rank_in_node): - return self.dp_rank_clients[dp_rank_in_node].get_tree_total_tokens_num() - - def get_unrefed_tokens_num(self, dp_rank_in_node): - return self.dp_rank_clients[dp_rank_in_node].get_unrefed_tokens_num() diff --git a/lightllm/server/router/req_queue/chunked_prefill/impl.py b/lightllm/server/router/req_queue/chunked_prefill/impl.py index f1dae4cac..8fa6248b3 100644 --- a/lightllm/server/router/req_queue/chunked_prefill/impl.py +++ b/lightllm/server/router/req_queue/chunked_prefill/impl.py @@ -3,6 +3,7 @@ from ...batch import Batch, Req from lightllm.server.router.req_queue.base_queue import BaseQueue from lightllm.common.basemodel.infer_lock import g_router_lock +from 
lightllm.utils.envs_utils import get_page_size class ChunkedPrefillQueue(BaseQueue): @@ -32,9 +33,11 @@ def _can_add_new_req(self, req: Req, is_busy, new_batch_first_router_need_tokens need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max() with g_router_lock.obj: + page_size = get_page_size() + page_remaining = (len(self.cache_len_list) - 1) * page_size if page_size > 1 else 0 ok_token_num = ( need_max_token_num + self.router.shared_token_load.get_frozened_token_count(self.dp_index) - < self.max_total_tokens + < self.max_total_tokens - page_remaining ) ok_req_num = len(self.cache_len_list) <= self.running_max_req_size diff --git a/test/benchmark/static_inference/model_infer.py b/test/benchmark/static_inference/model_infer.py index 58a2c44a0..b2c767c5d 100644 --- a/test/benchmark/static_inference/model_infer.py +++ b/test/benchmark/static_inference/model_infer.py @@ -260,7 +260,7 @@ def run_forward_once( total_token_num = batch_size * input_len mem_indexes = model_part.req_manager.mem_manager.alloc( test_data.shape[0], b_req_idx, b_seq_len, b_ready_cache_len, True - ).cuda() + ) b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cpu") rank_id = model_kvargs["rank_id"] @@ -323,7 +323,7 @@ def run_forward_once( step_start = time.time() total_token_num += batch_size b_seq_len += 1 - mem_indexes = model_part.req_manager.mem_manager.alloc(predict_ids.shape[0], b_req_idx, b_seq_len).cuda() + mem_indexes = model_part.req_manager.mem_manager.alloc(predict_ids.shape[0], b_req_idx, b_seq_len) max_len_in_batch = input_len + i + 1 logits = decode_fn( model_part, From e70c85a7f942665f61ac9a1e66b436ee1a3881f6 Mon Sep 17 00:00:00 2001 From: niushengxiao Date: Tue, 2 Sep 2025 17:37:29 +0800 Subject: [PATCH 6/9] feat: add alloc_paged_token_indices function in req_manager --- lightllm/common/basemodel/basemodel.py | 8 +- lightllm/common/basemodel/cuda_graph.py | 4 +- ...ager.py => deepseek2_paged_mem_manager.py} | 4 +- lightllm/common/mem_manager.py | 18 +- lightllm/common/mem_utils.py | 4 +- .../common/page_size_variable_mem_manager.py | 184 ------------------ lightllm/common/paged_mem_manager.py | 94 +++++++++ lightllm/common/req_manager.py | 71 ++++++- .../deepseek2/flashattention_infer_struct.py | 5 +- .../models/deepseek2/flashinfer_struct.py | 7 +- lightllm/models/deepseek2/model.py | 4 +- .../triton_kernel/repack_kv_index.py | 112 ++++++++++- .../llama/flashattention_infer_struct.py | 10 +- lightllm/models/llama/flashinfer_struct.py | 12 +- .../dynamic_prompt/paged_radix_cache.py | 10 +- .../router/dynamic_prompt/radix_cache.py | 4 +- .../server/router/model_infer/infer_batch.py | 6 +- .../generic_padded_pre_process.py | 9 +- .../mode_backend/generic_pre_process.py | 8 +- .../benchmark/static_inference/model_infer.py | 6 +- .../static_inference/model_infer_mtp.py | 6 +- 21 files changed, 324 insertions(+), 262 deletions(-) rename lightllm/common/{deepseek2_page_size_variable_mem_manager.py => deepseek2_paged_mem_manager.py} (80%) delete mode 100755 lightllm/common/page_size_variable_mem_manager.py create mode 100755 lightllm/common/paged_mem_manager.py diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index b746e322c..6aba997c8 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -687,8 +687,8 @@ def _check_max_len_infer(self): b_seq_len = torch.ones(1, dtype=torch.int32, device="cuda") b_seq_len[:] = self.batch_max_tokens b_ready_cache_len = torch.zeros(1, 
dtype=torch.int32, device="cuda") - mem_indexes = self.mem_manager.alloc( - len(dummy_input_ids), b_req_idx, b_seq_len, b_ready_cache_len, True + mem_indexes = self.req_manager.alloc_token_indices( + len(dummy_input_ids), b_req_idx, b_seq_len, b_ready_cache_len ).cuda() total_token_num = self.batch_max_tokens b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cuda") @@ -759,12 +759,14 @@ def _autotune_warmup(self): 0, 10000, (input_len,), dtype=torch.int32, device="cuda", generator=rand_gen ) b_req_idx = torch.tensor([self.req_manager.alloc()], dtype=torch.int32, device="cuda") - mem_indexes = self.mem_manager.alloc(len(dummy_input_ids)).cuda() b_seq_len = torch.ones(1, dtype=torch.int32, device="cuda") b_seq_len[:] = input_len b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda") total_token_num = input_len b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cuda") + mem_indexes = self.req_manager.alloc_token_indices( + len(dummy_input_ids), b_req_idx, b_seq_len, b_ready_cache_len + ).cuda() model_input = ModelInput( batch_size=1, total_token_num=total_token_num, diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/cuda_graph.py index db77298c7..4dabad0ff 100644 --- a/lightllm/common/basemodel/cuda_graph.py +++ b/lightllm/common/basemodel/cuda_graph.py @@ -202,7 +202,7 @@ def warmup(self, model): b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda") b_seq_len.fill_(seq_len) b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda") - mem_indexes = model.mem_manager.alloc(len(input_ids), b_req_idx, b_seq_len).cuda() + mem_indexes = model.req_manager.alloc_token_indices(len(input_ids), b_req_idx, b_seq_len).cuda() model_input = ModelInput( batch_size=batch_size, @@ -258,7 +258,7 @@ def warmup_overlap(self, model): b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda") b_seq_len.fill_(seq_len) b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda") - mem_indexes = model.mem_manager.alloc(len(input_ids), b_req_idx, b_seq_len).cuda() + mem_indexes = model.req_manager.alloc_token_indices(len(input_ids), b_req_idx, b_seq_len).cuda() micro_batch = ModelInput( is_prefill=False, diff --git a/lightllm/common/deepseek2_page_size_variable_mem_manager.py b/lightllm/common/deepseek2_paged_mem_manager.py similarity index 80% rename from lightllm/common/deepseek2_page_size_variable_mem_manager.py rename to lightllm/common/deepseek2_paged_mem_manager.py index 6c3cd7014..6067a7fb1 100755 --- a/lightllm/common/deepseek2_page_size_variable_mem_manager.py +++ b/lightllm/common/deepseek2_paged_mem_manager.py @@ -1,7 +1,7 @@ import torch import numpy as np from .deepseek2_mem_manager import Deepseek2MemoryManager -from .page_size_variable_mem_manager import PageSizeVariableMemoryManager +from .paged_mem_manager import PagedMemoryManager from lightllm.utils.log_utils import init_logger from lightllm.utils.envs_utils import get_page_size @@ -13,7 +13,7 @@ def cdiv(a, b): logger = init_logger(__name__) -class Deepseek2PageSizeVariableMemoryManager(PageSizeVariableMemoryManager, Deepseek2MemoryManager): +class Deepseek2PagedMemoryManager(PagedMemoryManager, Deepseek2MemoryManager): def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction) diff --git a/lightllm/common/mem_manager.py b/lightllm/common/mem_manager.py index 8ba047ee8..4142ce4aa 100755 --- 
a/lightllm/common/mem_manager.py +++ b/lightllm/common/mem_manager.py @@ -52,8 +52,6 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False layer_num, ) self.HOLD_TOKEN_MEMINDEX = self.size - # MemoryManager也需要个引用备份,供内部使用 - self.req_to_token_indexs = None def get_cell_size(self): return 2 * self.head_num * self.head_dim * self.layer_num * torch._utils._element_size(self.dtype) @@ -245,9 +243,7 @@ def _write_kv_move_data_p2p(self, token_indexes: torch.Tensor, buffer_tensor: to def _free_buffers(self): self.kv_buffer = None - def alloc( - self, need_size, b_req_idx=None, b_seq_len=None, b_ready_cache_len=None, is_prefill=False - ) -> torch.Tensor: + def alloc(self, need_size) -> torch.Tensor: if need_size > self.mark_end - self.mark_start: logger.error(f"warn no enough cache need_size {need_size} left_size {self.can_use_mem_size}") assert False, "error alloc state" @@ -261,9 +257,6 @@ def alloc( self.shared_can_use_token_num.set_value(self.can_use_mem_size) return ans - def set_prefix_cache_to_req(self, req_idx: int, start: int, end: int, values: torch.Tensor): - self.req_to_token_indexs[req_idx, start:end] = values - def free(self, free_index: Union[torch.Tensor, List[int]]): """_summary_ @@ -342,17 +335,8 @@ def __init__(self) -> None: SharedInt(f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}") for rank_in_node in range(0, self.node_world_size, self.dp_world_size) ] - self.shared_tp_info_pages = [ - SharedInt(f"{get_unique_server_name()}_mem_manger_can_use_page_num_{rank_in_node}") - for rank_in_node in range(0, self.node_world_size, self.dp_world_size) - ] def get_unrefed_token_num(self, dp_rank_in_node: int): if self.is_multinode_tp: return self.shared_tp_infos[0].get_value() return self.shared_tp_infos[dp_rank_in_node].get_value() - - def get_unrefed_page_num(self, dp_rank_in_node: int): - if self.is_multinode_tp: - return self.shared_tp_info_pages[0].get_value() - return self.shared_tp_info_pages[dp_rank_in_node].get_value() diff --git a/lightllm/common/mem_utils.py b/lightllm/common/mem_utils.py index 5f3ee6164..aab7adc08 100644 --- a/lightllm/common/mem_utils.py +++ b/lightllm/common/mem_utils.py @@ -4,7 +4,7 @@ from lightllm.common.export_calibration_mem_manager import ExportCalibrationMemoryManager from lightllm.common.ppl_int8kv_mem_manager import PPLINT8KVMemoryManager from lightllm.common.ppl_int4kv_mem_manager import PPLINT4KVMemoryManager -from lightllm.common.page_size_variable_mem_manager import PageSizeVariableMemoryManager +from lightllm.common.paged_mem_manager import PagedMemoryManager from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) @@ -30,7 +30,7 @@ def select_mem_manager_class(mode): memory_manager_class = ExportCalibrationMemoryManager logger.info("Using mode export fp8kv calibration") elif "page_size_variable" in mode: - memory_manager_class = PageSizeVariableMemoryManager + memory_manager_class = PagedMemoryManager logger.info("Page size will be variable") else: memory_manager_class = MemoryManager diff --git a/lightllm/common/page_size_variable_mem_manager.py b/lightllm/common/page_size_variable_mem_manager.py deleted file mode 100755 index 8456f2902..000000000 --- a/lightllm/common/page_size_variable_mem_manager.py +++ /dev/null @@ -1,184 +0,0 @@ -import torch -import numpy as np -from .mem_manager import MemoryManager -from typing import List, Union -from lightllm.utils.log_utils import init_logger -from lightllm.utils.envs_utils import get_unique_server_name, get_page_size 
-from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt -from lightllm.utils.dist_utils import get_current_rank_in_node - - -def cdiv(a, b): - return (a + b - 1) // b - - -logger = init_logger(__name__) - - -class PageSizeVariableMemoryManager(MemoryManager): - def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): - super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction) - self.req_to_page_indexs = None - page_size = get_page_size() - self.page_idx_pool = torch.arange( - 0, cdiv(self.size, page_size), dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True - ) - self.mark_page_start = 0 - self.can_use_page_size = cdiv(self.size, page_size) - - rank_in_node = get_current_rank_in_node() - self.shared_can_use_page_num = SharedInt( - f"{get_unique_server_name()}_mem_manger_can_use_page_num_{rank_in_node}" - ) - self.shared_can_use_page_num.set_value(self.can_use_page_size) - - def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): - self.kv_buffer = torch.empty( - (layer_num, cdiv(size, get_page_size()) * get_page_size(), 2 * head_num, head_dim), - dtype=dtype, - device="cuda", - ) - - # 要求长度必须是page_size的整数倍,page内token索引必须连续 - def check_cache_page_valid(self, values: torch.Tensor): - end = len(values) - assert end % self.page_size == 0, "Values length must be a multiple of page size" - total_pages = end // self.page_size - for page_idx in range(total_pages): - values_start = page_idx * self.page_size - values_end = min((page_idx + 1) * self.page_size, end) - page_token_idxs = values[values_start:values_end] - if len(page_token_idxs) > 1: - expected_idxs = torch.arange( - page_token_idxs[0], - page_token_idxs[0] + len(page_token_idxs), - dtype=page_token_idxs.dtype, - device=page_token_idxs.device, - ) - if not torch.equal(page_token_idxs, expected_idxs): - return False - return True - - def set_prefix_cache_to_req(self, req_idx: int, start: int, end: int, values: torch.Tensor): - # assert self.check_cache_page_valid(values), "Values must be valid for page size" - page_size = get_page_size() - self.req_to_page_indexs[req_idx, start // page_size : end // page_size] = values[::page_size] // page_size - self.req_to_token_indexs[req_idx, start:end] = values - - def expand_by_page_size(self, b_token_len, page_size): - # 将seq_len按page整数倍展开,例如seq_len = [9,9,9] -> page_len = [4,4,1,4,4,1,4,4,1], page_size = 4 - b_page_len = cdiv(b_token_len, page_size) - need_pages_num = b_page_len.sum() - p_token_len = torch.full((need_pages_num,), page_size, dtype=b_token_len.dtype, device=b_token_len.device) - cumsum_pages = torch.cumsum(b_page_len, dim=0) - last_page_positions = cumsum_pages - 1 - remainders = b_token_len - (b_page_len - 1) * page_size - p_token_len[last_page_positions] = remainders - return need_pages_num, b_page_len, p_token_len - - def get_paged_token_indexs(self, b_req_idx, page_size, b_seq_len, b_ready_cache_len, is_prefill): - if is_prefill: - b_req_idx = b_req_idx.cuda() - b_seq_len = b_seq_len.cuda() - b_ready_cache_len = b_ready_cache_len.cuda() - - b_token_len = b_seq_len - b_ready_cache_len - total_pages_needed, b_page_len, p_token_len = self.expand_by_page_size(b_token_len, page_size) - if self.can_use_page_size < total_pages_needed: - raise RuntimeError( - f"No available pages for alloc. 
remaining: {self.can_use_page_size}, needed: {total_pages_needed}" - ) - - allocated_pages = self.page_idx_pool[ - self.mark_page_start : self.mark_page_start + total_pages_needed - ].cuda() - - def get_offsets_by_length(b_len, max_len): - # 例:b_len = [3,4,5] -> [0,1,2,0,1,2,3,0,1,2,3,4] - offsets = torch.arange(max_len, dtype=b_len.dtype, device=b_len.device) - offset_mask = offsets.unsqueeze(0) < b_len.unsqueeze(1) - return torch.masked_select(offsets, offset_mask) - - page_offsets = get_offsets_by_length(b_page_len, b_page_len.max()) - token_offsets = get_offsets_by_length(p_token_len, page_size) - - # 更新req_to_page_indexs, b_ready_cache_len必整除page_size - page_starts = b_ready_cache_len // page_size - req_id = torch.repeat_interleave( - torch.arange(len(b_req_idx), dtype=b_token_len.dtype, device=b_token_len.device), b_page_len - ) - self.req_to_page_indexs[b_req_idx[req_id], page_starts[req_id] + page_offsets] = allocated_pages - - self.mark_page_start += total_pages_needed - self.can_use_page_size -= total_pages_needed - page_bases = allocated_pages * page_size - return torch.repeat_interleave(page_bases, p_token_len) + token_offsets - else: - b_seq_len = b_seq_len.cuda() - b_req_idx = b_req_idx.cuda() - need_new_page_mask = (b_seq_len - 1) % page_size == 0 - new_pages_num = need_new_page_mask.sum() - if self.can_use_page_size < new_pages_num: - raise RuntimeError( - f"No available pages for alloc. remaining: {self.can_use_page_size}, needed: {new_pages_num}" - ) - - token_idxs = torch.zeros_like(b_seq_len, device=b_seq_len.device) - if new_pages_num > 0: - new_pages = self.page_idx_pool[self.mark_page_start : self.mark_page_start + new_pages_num].cuda() - self.mark_page_start += new_pages_num - self.can_use_page_size -= new_pages_num - token_idxs[need_new_page_mask] = new_pages * page_size - - # 需要更新req_to_page_indexs - new_page_req_indices = b_req_idx[need_new_page_mask] - page_positions = (b_seq_len[need_new_page_mask] - 1) // page_size - self.req_to_page_indexs[new_page_req_indices, page_positions] = new_pages - - mask = ~need_new_page_mask - if mask.any(): - seq_lens = b_seq_len[mask] - token_idxs[mask] = ( - self.req_to_token_indexs[b_req_idx[mask], seq_lens - 2] // page_size * page_size - + (seq_lens - 1) % page_size - ) - return token_idxs - - def alloc(self, need_size, b_req_idx, b_seq_len, b_ready_cache_len=None, is_prefill=False) -> torch.Tensor: - page_size = get_page_size() - token_idxs = self.get_paged_token_indexs(b_req_idx, page_size, b_seq_len, b_ready_cache_len, is_prefill) - self.can_use_mem_size -= need_size - self.shared_can_use_token_num.set_value(self.can_use_mem_size) - self.shared_can_use_page_num.set_value(self.can_use_page_size) - return token_idxs - - def free(self, free_index: Union[torch.Tensor, List[int]]): - self.can_use_mem_size += len(free_index) - self.shared_can_use_token_num.set_value(self.can_use_mem_size) - - page_size = get_page_size() - if isinstance(free_index, list): - free_index = torch.tensor(free_index, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True) - - if len(free_index) == 0: - return - - base_free_index = free_index[free_index % page_size == 0] - page_indices = base_free_index // page_size - for page_idx in sorted(page_indices, reverse=True): # 逆序放回,保持池的相对顺序 - self.mark_page_start -= 1 - self.page_idx_pool[self.mark_page_start] = page_idx - self.can_use_page_size += 1 - self.shared_can_use_page_num.set_value(self.can_use_page_size) - - return - - def free_all(self): - super().free_all() - page_size = 
get_page_size() - self.mark_page_start = 0 - self.can_use_page_size = cdiv(self.size, page_size) - self.shared_can_use_page_num.set_value(self.can_use_page_size) - self.page_idx_pool = torch.arange( - 0, cdiv(self.size, page_size), dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True - ) diff --git a/lightllm/common/paged_mem_manager.py b/lightllm/common/paged_mem_manager.py new file mode 100755 index 000000000..dc97eae6e --- /dev/null +++ b/lightllm/common/paged_mem_manager.py @@ -0,0 +1,94 @@ +import torch +import numpy as np +from .mem_manager import MemoryManager +from typing import List, Union +from lightllm.utils.log_utils import init_logger +from lightllm.utils.envs_utils import get_unique_server_name, get_page_size +from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt +from lightllm.utils.dist_utils import get_current_rank_in_node + + +def cdiv(a, b): + return (a + b - 1) // b + + +logger = init_logger(__name__) + + +class PagedMemoryManager(MemoryManager): + def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): + super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction) + page_size = get_page_size() + self.mem_page_state = torch.arange( + 0, cdiv(self.size, page_size), dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True + ) + self.mark_page_start = 0 + self.can_use_page_size = cdiv(self.size, page_size) + + def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): + self.kv_buffer = torch.empty( + (layer_num, cdiv(size, get_page_size()) * get_page_size(), 2 * head_num, head_dim), + dtype=dtype, + device="cuda", + ) + + # 要求长度必须是page_size的整数倍,page内token索引必须连续 + def check_cache_page_valid(self, values: torch.Tensor): + end = len(values) + assert end % self.page_size == 0, "Values length must be a multiple of page size" + total_pages = end // self.page_size + for page_idx in range(total_pages): + values_start = page_idx * self.page_size + values_end = min((page_idx + 1) * self.page_size, end) + page_token_idxs = values[values_start:values_end] + if len(page_token_idxs) > 1: + expected_idxs = torch.arange( + page_token_idxs[0], + page_token_idxs[0] + len(page_token_idxs), + dtype=page_token_idxs.dtype, + device=page_token_idxs.device, + ) + if not torch.equal(page_token_idxs, expected_idxs): + return False + return True + + def alloc(self, need_size) -> torch.Tensor: + if self.can_use_page_size < need_size: + raise RuntimeError( + f"No available pages for alloc. 
remaining: {self.can_use_page_size}, needed: {need_size}" + ) + new_pages = self.mem_page_state[self.mark_page_start : self.mark_page_start + need_size].cuda() + self.mark_page_start += need_size + self.can_use_page_size -= need_size + self.can_use_mem_size -= need_size * get_page_size() + self.shared_can_use_token_num.set_value(self.can_use_mem_size) + return new_pages + + def free(self, free_index: Union[torch.Tensor, List[int]]): + self.can_use_mem_size += len(free_index) + self.shared_can_use_token_num.set_value(self.can_use_mem_size) + + page_size = get_page_size() + if isinstance(free_index, list): + free_index = torch.tensor(free_index, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True) + + if len(free_index) == 0: + return + + base_free_index = free_index[free_index % page_size == 0] + page_indices = base_free_index // page_size + for page_idx in sorted(page_indices, reverse=True): # 逆序放回,保持池的相对顺序 + self.mark_page_start -= 1 + self.mem_page_state[self.mark_page_start] = page_idx + self.can_use_page_size += 1 + + return + + def free_all(self): + super().free_all() + page_size = get_page_size() + self.mark_page_start = 0 + self.can_use_page_size = cdiv(self.size, page_size) + self.mem_page_state = torch.arange( + 0, cdiv(self.size, page_size), dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True + ) diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py index fb5d564d6..5435d3817 100644 --- a/lightllm/common/req_manager.py +++ b/lightllm/common/req_manager.py @@ -11,6 +11,10 @@ logger = init_logger(__name__) +def cdiv(a, b): + return (a + b - 1) // b + + class _ReqNode: def __init__(self, index): self.index = index @@ -62,20 +66,69 @@ def __init__(self, max_request_num, max_sequence_length, mem_manager: MemoryMana self.req_to_token_indexs = torch.zeros( (max_request_num + 1, max_sequence_length), dtype=torch.int32, device="cuda" ) - mem_manager.req_to_token_indexs = self.req_to_token_indexs - if hasattr(mem_manager, "req_to_page_indexs"): - page_size = get_page_size() - self.req_to_page_indexs = torch.zeros( - (max_request_num + 1, (max_sequence_length + page_size - 1) // page_size), - dtype=torch.int32, - device="cuda", - ) - mem_manager.req_to_page_indexs = self.req_to_page_indexs self.mem_manager = mem_manager self.req_sampling_params_manager = ReqSamplingParamsManager(max_request_num) self.max_request_num = max_request_num self.HOLD_REQUEST_ID = max_request_num + def expand_by_page_size(self, b_token_len, page_size): + # 将seq_len按page整数倍展开,例如seq_len = [9,9,9] -> page_len = [4,4,1,4,4,1,4,4,1], page_size = 4 + b_page_len = cdiv(b_token_len, page_size) + need_pages_num = b_page_len.sum() + p_token_len = torch.full((need_pages_num,), page_size, dtype=b_token_len.dtype, device=b_token_len.device) + cumsum_pages = torch.cumsum(b_page_len, dim=0) + last_page_positions = cumsum_pages - 1 + remainders = b_token_len - (b_page_len - 1) * page_size + p_token_len[last_page_positions] = remainders + return need_pages_num, b_page_len, p_token_len + + def alloc_paged_token_indices(self, b_req_idx, page_size, b_seq_len, b_ready_cache_len): + if b_ready_cache_len is not None: + # prefill + b_req_idx = b_req_idx.cuda() + b_seq_len = b_seq_len.cuda() + b_ready_cache_len = b_ready_cache_len.cuda() + + b_token_len = b_seq_len - b_ready_cache_len + total_pages_needed, b_page_len, p_token_len = self.expand_by_page_size(b_token_len, page_size) + allocated_pages = self.mem_manager.alloc(total_pages_needed) + + def get_offsets_by_length(b_len, 
max_len): + # 例:b_len = [3,4,5] -> [0,1,2,0,1,2,3,0,1,2,3,4] + offsets = torch.arange(max_len, dtype=b_len.dtype, device=b_len.device) + offset_mask = offsets.unsqueeze(0) < b_len.unsqueeze(1) + return torch.masked_select(offsets, offset_mask) + + token_offsets = get_offsets_by_length(p_token_len, page_size) + page_bases = allocated_pages * page_size + return torch.repeat_interleave(page_bases, p_token_len) + token_offsets + else: + # decode + b_seq_len = b_seq_len.cuda() + b_req_idx = b_req_idx.cuda() + need_new_page_mask = (b_seq_len - 1) % page_size == 0 + new_pages_num = need_new_page_mask.sum() + token_idxs = torch.zeros_like(b_seq_len, device=b_seq_len.device) + if new_pages_num > 0: + new_pages = self.mem_manager.alloc(new_pages_num) + token_idxs[need_new_page_mask] = new_pages * page_size + + mask = ~need_new_page_mask + if mask.any(): + seq_lens = b_seq_len[mask] + token_idxs[mask] = ( + self.req_to_token_indexs[b_req_idx[mask], seq_lens - 2] // page_size * page_size + + (seq_lens - 1) % page_size + ) + return token_idxs + + def alloc_token_indices(self, need_size, b_req_idx, b_seq_len, b_ready_cache_len=None) -> torch.Tensor: + page_size = get_page_size() + if page_size > 1: + return self.alloc_paged_token_indices(b_req_idx, page_size, b_seq_len, b_ready_cache_len) + else: + return self.mem_manager.alloc(need_size) + def alloc(self): return self.req_list.alloc() diff --git a/lightllm/models/deepseek2/flashattention_infer_struct.py b/lightllm/models/deepseek2/flashattention_infer_struct.py index 52ba3beb4..5355990f1 100644 --- a/lightllm/models/deepseek2/flashattention_infer_struct.py +++ b/lightllm/models/deepseek2/flashattention_infer_struct.py @@ -56,7 +56,10 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): if "page_size_variable" in model.mode: length = cdiv(max_seq_len_k, self.page_size) - self.page_table[:, :length].copy_(model.req_manager.req_to_page_indexs[self.b_req_idx, :length]) + token_indexs = model.req_manager.req_to_token_indexs[ + self.b_req_idx, : length * self.page_size : self.page_size + ] + self.page_table[:, :length].copy_(token_indexs // self.page_size) self.page_table[:, length:].fill_(0) else: self.page_table[:, :max_seq_len_k].copy_( diff --git a/lightllm/models/deepseek2/flashinfer_struct.py b/lightllm/models/deepseek2/flashinfer_struct.py index 25ebe3889..1cb007377 100644 --- a/lightllm/models/deepseek2/flashinfer_struct.py +++ b/lightllm/models/deepseek2/flashinfer_struct.py @@ -4,7 +4,7 @@ import torch.distributed as dist from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo from lightllm.utils.envs_utils import get_env_start_args, get_page_size -from lightllm.models.deepseek2.triton_kernel.repack_kv_index import repack_kv_index +from lightllm.models.deepseek2.triton_kernel.repack_kv_index import repack_kv_index, repack_paged_kv_index_from_tokens def cdiv(a, b): @@ -42,12 +42,13 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): if "page_size_variable" in model.mode: b_page_len = cdiv(self.b_seq_len, self.page_size) self.kv_starts[1:] = b_page_len.cumsum(0) - repack_kv_index( - self.req_manager.req_to_page_indexs, + repack_paged_kv_index_from_tokens( + self.req_manager.req_to_token_indexs, self.b_req_idx, b_page_len, self.kv_starts[:-1], cdiv(self.max_len_in_batch, self.page_size), + self.page_size, self.kv_indices, ) else: diff --git a/lightllm/models/deepseek2/model.py b/lightllm/models/deepseek2/model.py index c2380c4ab..50771e015 100644 --- a/lightllm/models/deepseek2/model.py +++ 
b/lightllm/models/deepseek2/model.py @@ -10,7 +10,7 @@ from lightllm.models.llama.model import LlamaTpPartModel from lightllm.common.deepseek2_mem_manager import Deepseek2MemoryManager -from lightllm.common.deepseek2_page_size_variable_mem_manager import Deepseek2PageSizeVariableMemoryManager +from lightllm.common.deepseek2_paged_mem_manager import Deepseek2PagedMemoryManager from lightllm.common.deepseek2_fp8kv_mem_manager import Deepseek2FP8KVMemoryManager from lightllm.utils.log_utils import init_logger from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale @@ -99,7 +99,7 @@ def _init_mem_manager(self): if "triton_fp8kv" in self.mode: manager_class = Deepseek2FP8KVMemoryManager elif "page_size_variable" in self.mode: - manager_class = Deepseek2PageSizeVariableMemoryManager + manager_class = Deepseek2PagedMemoryManager elif self.mode: raise ValueError(f"Unsupported mode for deepseek2: {self.mode}") diff --git a/lightllm/models/deepseek2/triton_kernel/repack_kv_index.py b/lightllm/models/deepseek2/triton_kernel/repack_kv_index.py index e86d2e819..4cb7012b2 100644 --- a/lightllm/models/deepseek2/triton_kernel/repack_kv_index.py +++ b/lightllm/models/deepseek2/triton_kernel/repack_kv_index.py @@ -33,6 +33,42 @@ def _fwd_kernel_repack_kv_index( return +@triton.jit +def _fwd_kernel_repack_page_kv_index_from_tokens( + req_to_token_indexs, + req_index, + out_kv_index, + seq_len, + start_loc, + page_size, + token_stride_h, + SEQ_BLOCK: tl.constexpr, +): + cur_batch = tl.program_id(0) + start_seq_n = tl.program_id(1) + + cur_batch_seq_len = tl.load(seq_len + cur_batch) + cur_batch_req_idx = tl.load(req_index + cur_batch) + cur_batch_start_loc = tl.load(start_loc + cur_batch) + + offs_seq = (start_seq_n * SEQ_BLOCK + tl.arange(0, SEQ_BLOCK)) * page_size + block_end_loc = (tl.minimum((start_seq_n + 1) * SEQ_BLOCK, cur_batch_seq_len)) * page_size + token_data = tl.load( + req_to_token_indexs + token_stride_h * cur_batch_req_idx + offs_seq, + mask=offs_seq < block_end_loc, + other=0, + ) + valid_mask = (token_data % page_size) == 0 + valid_mask = valid_mask & (token_data > 0) # 确保是有效的 token 索引 + page_data = tl.where(valid_mask, token_data // page_size, 0) + + offs_seq = start_seq_n * SEQ_BLOCK + tl.arange(0, SEQ_BLOCK) + block_end_loc = tl.minimum((start_seq_n + 1) * SEQ_BLOCK, cur_batch_seq_len) + out_kv_index_ptr = out_kv_index + cur_batch_start_loc + offs_seq + tl.store(out_kv_index_ptr, page_data, mask=offs_seq < block_end_loc) + return + + @torch.no_grad() def repack_kv_index(kv_index, req_index, seq_len, start_loc, max_seq_len, out_kv_index): batch_size = req_index.shape[0] @@ -58,6 +94,51 @@ def repack_kv_index(kv_index, req_index, seq_len, start_loc, max_seq_len, out_kv return +@torch.no_grad() +def repack_paged_kv_index_from_tokens( + req_to_token_indexs, req_index, seq_len, start_loc, max_seq_len, page_size, out_kv_index +): + batch_size = req_index.shape[0] + out_kv_index.zero_() + + BLOCK = 64 + grid = ( + batch_size, + triton.cdiv(max_seq_len, BLOCK), + ) + + _fwd_kernel_repack_page_kv_index_from_tokens[grid]( + req_to_token_indexs, + req_index, + out_kv_index, + seq_len, + start_loc, + page_size, + req_to_token_indexs.stride(0), + SEQ_BLOCK=BLOCK, + num_warps=8, + num_stages=1, + ) + return + + +def ref_repack_page_kv_index_with_token_input( + req_to_token_indexs, req_index, seq_len, start_loc, max_seq_len, page_size, out_kv_index +): + page_indexs = torch.zeros_like(req_to_token_indexs) + valid_mask = req_to_token_indexs % page_size == 0 + batch_size, seq_len_dim = 
req_to_token_indexs.shape + valid_positions = torch.cumsum(valid_mask.int(), dim=1) - 1 + batch_indices = torch.arange(batch_size, device=req_to_token_indexs.device).unsqueeze(1).expand(-1, seq_len_dim) + page_indexs.view(-1).scatter_add_( + 0, + (batch_indices * seq_len_dim + torch.where(valid_mask, valid_positions, 0)).flatten(), + (torch.where(valid_mask, req_to_token_indexs // page_size, 0) * valid_mask.int()).flatten(), + ) + + repack_kv_index(page_indexs, req_index, seq_len, start_loc, max_seq_len, out_kv_index) + + def repack_kv_ref(req_to_token_indexs, b_req_idx, b_seq_len, b_start_loc, output): for b, sl, start in zip(b_req_idx, b_seq_len, b_start_loc): output[start : start + sl] = req_to_token_indexs[b][:sl] @@ -67,6 +148,7 @@ def repack_kv_ref(req_to_token_indexs, b_req_idx, b_seq_len, b_start_loc, output import torch.nn.functional as F BATCH, MAX_SEQ_LEN = 10, 1024 + PAGE_SIZE = 64 rand_idx = torch.randperm(2 * MAX_SEQ_LEN * BATCH).cuda().int() b_req_idx = torch.randperm(BATCH).cuda().int() b_seq_len = torch.randint(1, MAX_SEQ_LEN, (BATCH,)).cuda().int() @@ -77,14 +159,38 @@ def repack_kv_ref(req_to_token_indexs, b_req_idx, b_seq_len, b_start_loc, output .int() ) + # 为每个batch生成基于page的连续索引 + for b in range(2 * BATCH): + start_page_id = b * 100 # 确保不同batch有不同的page ID范围 + for token_idx in range(2 * MAX_SEQ_LEN): + page_offset = token_idx // PAGE_SIZE + token_in_page = token_idx % PAGE_SIZE + page_id = start_page_id + page_offset + token_index = page_id * PAGE_SIZE + token_in_page + req_to_token_indexs[b, token_idx] = token_index + output = torch.zeros((b_seq_len.sum(),)).cuda().int() ref = torch.zeros((b_seq_len.sum(),)).cuda().int() - for b, sl, start in zip(b_req_idx, b_seq_len, b_start_loc): - req_to_token_indexs[b][:sl] = rand_idx[start : start + sl] fn1 = lambda: repack_kv_ref(req_to_token_indexs, b_req_idx, b_seq_len, b_start_loc, ref) fn2 = lambda: repack_kv_index(req_to_token_indexs, b_req_idx, b_seq_len, b_start_loc, MAX_SEQ_LEN, output) ms1 = triton.testing.do_bench(fn1) ms2 = triton.testing.do_bench_cudagraph(fn2) - print(ms1, ms2) + print(f"repack_kv_index: ref={ms1:.3f}ms, triton={ms2:.3f}ms") assert torch.allclose(output.float(), ref.float()) + + b_page_len = triton.cdiv(b_seq_len, PAGE_SIZE) + page_output = torch.zeros((b_page_len.sum(),)).cuda().int() + page_ref = torch.zeros((b_page_len.sum(),)).cuda().int() + b_start_loc[1:] = b_page_len.cumsum(0)[:-1] + max_seq_len = triton.cdiv(MAX_SEQ_LEN, PAGE_SIZE) + fn3 = lambda: ref_repack_page_kv_index_with_token_input( + req_to_token_indexs, b_req_idx, b_page_len, b_start_loc, max_seq_len, PAGE_SIZE, page_ref + ) + fn4 = lambda: repack_paged_kv_index_from_tokens( + req_to_token_indexs, b_req_idx, b_page_len, b_start_loc, max_seq_len, PAGE_SIZE, page_output + ) + ms3 = triton.testing.do_bench(fn3) + ms4 = triton.testing.do_bench_cudagraph(fn4) + print(f"repack_paged_kv_index_from_tokens: ref={ms3:.3f}ms, triton={ms4:.3f}ms") + assert torch.allclose(page_output.float(), page_ref.float()) diff --git a/lightllm/models/llama/flashattention_infer_struct.py b/lightllm/models/llama/flashattention_infer_struct.py index 28611e901..51c302b62 100644 --- a/lightllm/models/llama/flashattention_infer_struct.py +++ b/lightllm/models/llama/flashattention_infer_struct.py @@ -36,7 +36,10 @@ def _init_flash_attention_state(self, model, input_ids: torch.Tensor): length = cdiv(self.max_seq_len, self.page_size) self.page_table = torch.empty((self.batch_size, length), dtype=torch.int32, device=input_ids.device) if "page_size_variable" in 
model.mode: - self.page_table.copy_(model.req_manager.req_to_page_indexs[self.b_req_idx, :length]) + token_indexs = model.req_manager.req_to_token_indexs[ + self.b_req_idx, : length * self.page_size : self.page_size + ] + self.page_table.copy_(token_indexs // self.page_size) else: self.page_table.copy_(model.req_manager.req_to_token_indexs[self.b_req_idx, :length]) else: @@ -56,7 +59,10 @@ def _init_flash_attention_state(self, model, input_ids: torch.Tensor): length = cdiv(max_seq_len_k, self.page_size) if "page_size_variable" in model.mode: - self.page_table[:, :length].copy_(model.req_manager.req_to_page_indexs[self.b_req_idx, :length]) + token_indexs = model.req_manager.req_to_token_indexs[ + self.b_req_idx, : length * self.page_size : self.page_size + ] + self.page_table[:, :length].copy_(token_indexs // self.page_size) else: self.page_table[:, :length].copy_(model.req_manager.req_to_token_indexs[self.b_req_idx, :length]) self.page_table[:, length:].fill_(0) diff --git a/lightllm/models/llama/flashinfer_struct.py b/lightllm/models/llama/flashinfer_struct.py index 3b9a378c4..4c696186f 100644 --- a/lightllm/models/llama/flashinfer_struct.py +++ b/lightllm/models/llama/flashinfer_struct.py @@ -4,7 +4,7 @@ import torch.distributed as dist from lightllm.models.llama.infer_struct import LlamaInferStateInfo from lightllm.utils.envs_utils import get_env_start_args, get_page_size -from lightllm.models.deepseek2.triton_kernel.repack_kv_index import repack_kv_index +from lightllm.models.deepseek2.triton_kernel.repack_kv_index import repack_kv_index, repack_paged_kv_index_from_tokens def cdiv(a, b): @@ -45,12 +45,13 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): b_page_len = cdiv(self.b_seq_len, self.page_size) self.kv_starts[1:] = b_page_len.cumsum(0) self.kv_last_page_len = self.b_seq_len - (b_page_len - 1) * self.page_size - repack_kv_index( - self.req_manager.req_to_page_indexs, + repack_paged_kv_index_from_tokens( + self.req_manager.req_to_token_indexs, self.b_req_idx, b_page_len, self.kv_starts[:-1], cdiv(self.max_kv_seq_len, self.page_size), + self.page_size, self.kv_indices, ) else: @@ -99,12 +100,13 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): b_page_len = cdiv(self.b_seq_len, self.page_size) kv_starts[1:] = b_page_len.cumsum(0) kv_last_page_len = self.b_seq_len - (b_page_len - 1) * self.page_size - repack_kv_index( - self.req_manager.req_to_page_indexs, + repack_paged_kv_index_from_tokens( + self.req_manager.req_to_token_indexs, self.b_req_idx, b_page_len, kv_starts[:-1], cdiv(self.max_kv_seq_len, self.page_size), + self.page_size, kv_indices, ) else: diff --git a/lightllm/server/router/dynamic_prompt/paged_radix_cache.py b/lightllm/server/router/dynamic_prompt/paged_radix_cache.py index 687fa1a22..537aba759 100644 --- a/lightllm/server/router/dynamic_prompt/paged_radix_cache.py +++ b/lightllm/server/router/dynamic_prompt/paged_radix_cache.py @@ -391,17 +391,15 @@ def _print_helper(self, node: TreeNode, indent): self._print_helper(child, indent=indent + 2) return - def free_radix_cache_to_get_enough_token( - self, need_token_num=None, b_seq_len=None, b_ready_cache_len=None, is_prefill=False - ): + def free_radix_cache_to_get_enough_token(self, need_token_num=None, b_seq_len=None, b_ready_cache_len=None): assert self.mem_manager is not None need_pages = 0 can_use_pages = 0 if hasattr(self.mem_manager, "can_use_page_size") and self.page_size > 1 and b_seq_len is not None: - def get_need_page_size(page_size, b_seq_len, b_ready_cache_len=None, 
is_prefill=False): + def get_need_page_size(page_size, b_seq_len, b_ready_cache_len=None): need_new_pages = 0 - if is_prefill: + if b_ready_cache_len is not None: need_tokens_array = b_seq_len - b_ready_cache_len need_pages_array = (need_tokens_array + page_size - 1) // page_size need_new_pages = need_pages_array.sum() @@ -410,7 +408,7 @@ def get_need_page_size(page_size, b_seq_len, b_ready_cache_len=None, is_prefill= need_new_pages = mask.sum() return need_new_pages - need_pages = get_need_page_size(self.page_size, b_seq_len, b_ready_cache_len, is_prefill) + need_pages = get_need_page_size(self.page_size, b_seq_len, b_ready_cache_len) can_use_pages = self.mem_manager.can_use_page_size if need_token_num > self.mem_manager.can_use_mem_size or need_pages > can_use_pages: need_evict_single_token_num = need_token_num - self.mem_manager.can_use_mem_size diff --git a/lightllm/server/router/dynamic_prompt/radix_cache.py b/lightllm/server/router/dynamic_prompt/radix_cache.py index a60d0a942..f00b6546a 100644 --- a/lightllm/server/router/dynamic_prompt/radix_cache.py +++ b/lightllm/server/router/dynamic_prompt/radix_cache.py @@ -333,9 +333,7 @@ def _print_helper(self, node: TreeNode, indent): self._print_helper(child, indent=indent + 2) return - def free_radix_cache_to_get_enough_token( - self, need_token_num=None, b_seq_len=None, b_ready_cache_len=None, is_prefill=False - ): + def free_radix_cache_to_get_enough_token(self, need_token_num=None, b_seq_len=None, b_ready_cache_len=None): assert self.mem_manager is not None if need_token_num > self.mem_manager.can_use_mem_size: need_evict_token_num = need_token_num - self.mem_manager.can_use_mem_size diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 20b9c0e27..01ae6c9c5 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -340,9 +340,7 @@ def _match_radix_cache(self): self.shared_kv_node = share_node ready_cache_len = share_node.node_prefix_total_len # 从 cpu 到 gpu 是流内阻塞操作 - g_infer_context.req_manager.mem_manager.set_prefix_cache_to_req( - self.req_idx, 0, ready_cache_len, value_tensor - ) + g_infer_context.req_manager.req_to_token_indexs[self.req_idx, 0:ready_cache_len] = value_tensor self.cur_kv_len = int(ready_cache_len) # 序列化问题, 该对象可能为numpy.int64,用 int(*)转换 self.shm_req.prompt_cache_len = self.cur_kv_len # 记录 prompt cache 的命中长度 @@ -460,7 +458,7 @@ def diverse_copy(self, req_manager, is_prefill): req = g_infer_context.requests_mapping[req_id] req.finish_status.set_status(FinishStatus.NO_FINISH) input_len = req.get_chuncked_input_token_len() - req_manager.mem_manager.set_prefix_cache_to_req(req.req_idx, prefix_len, input_len, cache_token_id) + req_manager.req_to_token_indexs[req.req_idx][prefix_len:input_len] = cache_token_id assert input_len == pre_input_len diff --git a/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py index 448c0d987..cdbe8d4e5 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py @@ -78,11 +78,12 @@ def padded_prepare_prefill_inputs( g_infer_state_lock.acquire() if g_infer_context.radix_cache is not None: g_infer_context.radix_cache.free_radix_cache_to_get_enough_token( - input_ids.shape[0] - padded_req_num, b_seq_len, b_ready_cache_len, True + input_ids.shape[0] - 
padded_req_num, b_seq_len, b_ready_cache_len ) - mem_indexes = g_infer_context.req_manager.mem_manager.alloc( - input_ids.shape[0] - padded_req_num, b_req_idx, b_seq_len, b_ready_cache_len, True + mem_indexes = g_infer_context.req_manager.alloc_token_indices( + input_ids.shape[0] - padded_req_num, b_req_idx, b_seq_len, b_ready_cache_len ) + g_infer_state_lock.release() if padded_req_num > 0: @@ -167,7 +168,7 @@ def padded_prepare_decode_inputs( g_infer_state_lock.acquire() if g_infer_context.radix_cache is not None: g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(b_seq_len.shape[0] - padded_req_num, b_seq_len) - mem_indexes = g_infer_context.req_manager.mem_manager.alloc( + mem_indexes = g_infer_context.req_manager.alloc_token_indices( b_seq_len.shape[0] - padded_req_num, b_req_idx, b_seq_len ) g_infer_state_lock.release() diff --git a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py index e5e871d83..6b44428d4 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py @@ -56,10 +56,10 @@ def prepare_prefill_inputs( g_infer_state_lock.acquire() if g_infer_context.radix_cache is not None: g_infer_context.radix_cache.free_radix_cache_to_get_enough_token( - input_ids.shape[0], b_seq_len, b_ready_cache_len, True + input_ids.shape[0], b_seq_len, b_ready_cache_len ) - mem_indexes = g_infer_context.req_manager.mem_manager.alloc( - input_ids.shape[0], b_req_idx, b_seq_len, b_ready_cache_len, True + mem_indexes = g_infer_context.req_manager.alloc_token_indices( + input_ids.shape[0], b_req_idx, b_seq_len, b_ready_cache_len ) g_infer_state_lock.release() @@ -116,7 +116,7 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In g_infer_state_lock.acquire() if g_infer_context.radix_cache is not None: g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(b_seq_len.shape[0], b_seq_len) - mem_indexes = g_infer_context.req_manager.mem_manager.alloc(b_seq_len.shape[0], b_req_idx, b_seq_len) + mem_indexes = g_infer_context.req_manager.alloc_token_indices(b_seq_len.shape[0], b_req_idx, b_seq_len) g_infer_state_lock.release() model_input = ModelInput( diff --git a/test/benchmark/static_inference/model_infer.py b/test/benchmark/static_inference/model_infer.py index b2c767c5d..b8012ae26 100644 --- a/test/benchmark/static_inference/model_infer.py +++ b/test/benchmark/static_inference/model_infer.py @@ -258,8 +258,8 @@ def run_forward_once( b_seq_len[i] = input_len total_token_num = batch_size * input_len - mem_indexes = model_part.req_manager.mem_manager.alloc( - test_data.shape[0], b_req_idx, b_seq_len, b_ready_cache_len, True + mem_indexes = model_part.req_manager.alloc_token_indices( + test_data.shape[0], b_req_idx, b_seq_len, b_ready_cache_len ) b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cpu") rank_id = model_kvargs["rank_id"] @@ -323,7 +323,7 @@ def run_forward_once( step_start = time.time() total_token_num += batch_size b_seq_len += 1 - mem_indexes = model_part.req_manager.mem_manager.alloc(predict_ids.shape[0], b_req_idx, b_seq_len) + mem_indexes = model_part.req_manager.alloc_token_indices(predict_ids.shape[0], b_req_idx, b_seq_len) max_len_in_batch = input_len + i + 1 logits = decode_fn( model_part, diff --git a/test/benchmark/static_inference/model_infer_mtp.py b/test/benchmark/static_inference/model_infer_mtp.py index 
9d684fd27..5efa57fa4 100644 --- a/test/benchmark/static_inference/model_infer_mtp.py +++ b/test/benchmark/static_inference/model_infer_mtp.py @@ -124,8 +124,8 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_ b_seq_len[i] = input_len total_token_num = input_len * batch_size - mem_indexes = main_model.req_manager.mem_manager.alloc( - test_data.shape[0], b_req_idx, b_seq_len, b_ready_cache_len, True + mem_indexes = main_model.req_manager.alloc_token_indices( + test_data.shape[0], b_req_idx, b_seq_len, b_ready_cache_len ).cuda() # Main model Prefill model_input = ModelInput( @@ -193,7 +193,7 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_ nopad_b_seq_idx = torch.tensor(nopad_b_seq_idx, dtype=torch.int32, device="cuda") nopad_b_seq_len = torch.tensor(nopad_b_seq_len, dtype=torch.int32, device="cuda") - mem_indexes = main_model.req_manager.mem_manager.alloc( + mem_indexes = main_model.req_manager.alloc_token_indices( batch_size * (len(draft_models) + 1), nopad_b_seq_idx, nopad_b_seq_len ).cuda() From 020268115372d6525f62f0f8789cafd15dc4890b Mon Sep 17 00:00:00 2001 From: niushengxiao Date: Wed, 3 Sep 2025 14:26:42 +0800 Subject: [PATCH 7/9] feat: replace page idxs with token idxs in paged_mem_manager --- lightllm/common/paged_mem_manager.py | 47 ++++----------- lightllm/common/req_manager.py | 57 +++++++++++-------- .../dynamic_prompt/paged_radix_cache.py | 32 ++--------- .../router/dynamic_prompt/radix_cache.py | 2 +- .../generic_padded_pre_process.py | 6 +- .../mode_backend/generic_pre_process.py | 6 +- 6 files changed, 57 insertions(+), 93 deletions(-) diff --git a/lightllm/common/paged_mem_manager.py b/lightllm/common/paged_mem_manager.py index dc97eae6e..f6bb25b5d 100755 --- a/lightllm/common/paged_mem_manager.py +++ b/lightllm/common/paged_mem_manager.py @@ -18,12 +18,6 @@ def cdiv(a, b): class PagedMemoryManager(MemoryManager): def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction) - page_size = get_page_size() - self.mem_page_state = torch.arange( - 0, cdiv(self.size, page_size), dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True - ) - self.mark_page_start = 0 - self.can_use_page_size = cdiv(self.size, page_size) def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): self.kv_buffer = torch.empty( @@ -53,42 +47,23 @@ def check_cache_page_valid(self, values: torch.Tensor): return True def alloc(self, need_size) -> torch.Tensor: - if self.can_use_page_size < need_size: - raise RuntimeError( - f"No available pages for alloc. 
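
# Sketch of the allocation contract this patch moves to (hedged reading of the hunks
# below): page-granular callers now ask the base manager for need_pages * page_size
# token indices, and the returned ids arrive in page-aligned runs, which is what the
# later view(-1, page_size) in _alloc_paged_token_indices relies on. Helper name is
# hypothetical.
def alloc_pages(mem_manager, need_pages: int, page_size: int):
    need_size = need_pages * page_size
    assert need_size % page_size == 0              # same guard as the patched alloc()
    token_idxs = mem_manager.alloc(need_size)      # flat token indices
    return token_idxs.view(need_pages, page_size)  # one page per row, by convention
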
remaining: {self.can_use_page_size}, needed: {need_size}" - ) - new_pages = self.mem_page_state[self.mark_page_start : self.mark_page_start + need_size].cuda() - self.mark_page_start += need_size - self.can_use_page_size -= need_size - self.can_use_mem_size -= need_size * get_page_size() - self.shared_can_use_token_num.set_value(self.can_use_mem_size) - return new_pages + assert need_size % get_page_size() == 0, "Need size must be a multiple of page size" + return super().alloc(need_size) def free(self, free_index: Union[torch.Tensor, List[int]]): - self.can_use_mem_size += len(free_index) - self.shared_can_use_token_num.set_value(self.can_use_mem_size) - page_size = get_page_size() - if isinstance(free_index, list): - free_index = torch.tensor(free_index, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True) - - if len(free_index) == 0: - return + if page_size == 1: + return super().free(free_index) + if isinstance(free_index, list): + free_index = torch.tensor(free_index) base_free_index = free_index[free_index % page_size == 0] - page_indices = base_free_index // page_size - for page_idx in sorted(page_indices, reverse=True): # 逆序放回,保持池的相对顺序 - self.mark_page_start -= 1 - self.mem_page_state[self.mark_page_start] = page_idx - self.can_use_page_size += 1 - + if len(base_free_index) == 0: + return + token_idxs = base_free_index[:, None] + torch.arange(page_size, device=free_index.device) + token_idxs = token_idxs.flatten() + super().free(token_idxs) return def free_all(self): super().free_all() - page_size = get_page_size() - self.mark_page_start = 0 - self.can_use_page_size = cdiv(self.size, page_size) - self.mem_page_state = torch.arange( - 0, cdiv(self.size, page_size), dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True - ) diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py index 5435d3817..7797f280c 100644 --- a/lightllm/common/req_manager.py +++ b/lightllm/common/req_manager.py @@ -71,7 +71,7 @@ def __init__(self, max_request_num, max_sequence_length, mem_manager: MemoryMana self.max_request_num = max_request_num self.HOLD_REQUEST_ID = max_request_num - def expand_by_page_size(self, b_token_len, page_size): + def _expand_by_page_size(self, b_token_len, page_size): # 将seq_len按page整数倍展开,例如seq_len = [9,9,9] -> page_len = [4,4,1,4,4,1,4,4,1], page_size = 4 b_page_len = cdiv(b_token_len, page_size) need_pages_num = b_page_len.sum() @@ -82,50 +82,57 @@ def expand_by_page_size(self, b_token_len, page_size): p_token_len[last_page_positions] = remainders return need_pages_num, b_page_len, p_token_len - def alloc_paged_token_indices(self, b_req_idx, page_size, b_seq_len, b_ready_cache_len): + def _alloc_paged_token_indices(self, b_req_idx, page_size, b_seq_len, b_ready_cache_len): if b_ready_cache_len is not None: # prefill - b_req_idx = b_req_idx.cuda() - b_seq_len = b_seq_len.cuda() - b_ready_cache_len = b_ready_cache_len.cuda() + b_seq_len = b_seq_len.cpu() + b_ready_cache_len = b_ready_cache_len.cpu() b_token_len = b_seq_len - b_ready_cache_len - total_pages_needed, b_page_len, p_token_len = self.expand_by_page_size(b_token_len, page_size) - allocated_pages = self.mem_manager.alloc(total_pages_needed) - - def get_offsets_by_length(b_len, max_len): - # 例:b_len = [3,4,5] -> [0,1,2,0,1,2,3,0,1,2,3,4] - offsets = torch.arange(max_len, dtype=b_len.dtype, device=b_len.device) - offset_mask = offsets.unsqueeze(0) < b_len.unsqueeze(1) - return torch.masked_select(offsets, offset_mask) - - token_offsets = 
get_offsets_by_length(p_token_len, page_size) - page_bases = allocated_pages * page_size - return torch.repeat_interleave(page_bases, p_token_len) + token_offsets + total_pages_needed, b_page_len, p_token_len = self._expand_by_page_size(b_token_len, page_size) + paged_token_idxs = self.mem_manager.alloc(total_pages_needed * page_size) + pages = paged_token_idxs.view(-1, page_size) + mask = torch.arange(page_size, device=p_token_len.device) < p_token_len.unsqueeze(1) + return pages[mask] else: # decode b_seq_len = b_seq_len.cuda() b_req_idx = b_req_idx.cuda() need_new_page_mask = (b_seq_len - 1) % page_size == 0 - new_pages_num = need_new_page_mask.sum() + new_pages_num = need_new_page_mask.sum().cpu() token_idxs = torch.zeros_like(b_seq_len, device=b_seq_len.device) if new_pages_num > 0: - new_pages = self.mem_manager.alloc(new_pages_num) - token_idxs[need_new_page_mask] = new_pages * page_size + new_pages_tokens = self.mem_manager.alloc(new_pages_num * page_size).cuda() + token_idxs[need_new_page_mask] = new_pages_tokens[::page_size] mask = ~need_new_page_mask if mask.any(): seq_lens = b_seq_len[mask] - token_idxs[mask] = ( - self.req_to_token_indexs[b_req_idx[mask], seq_lens - 2] // page_size * page_size - + (seq_lens - 1) % page_size - ) + token_idxs[mask] = self.req_to_token_indexs[b_req_idx[mask], seq_lens - 2] + 1 return token_idxs + def _get_need_paged_token_num(self, b_seq_len, b_ready_cache_len=None): + page_size = get_page_size() + if page_size == 1: + return 0 + + need_new_pages = 0 + if b_ready_cache_len is not None: + need_tokens_array = b_seq_len - b_ready_cache_len + need_pages_array = (need_tokens_array + page_size - 1) // page_size + need_new_pages = need_pages_array.sum() + else: + mask = (b_seq_len - 1) % page_size == 0 + need_new_pages = mask.sum() + return need_new_pages * page_size + + def calc_real_need_token_num(self, need_token_num, b_seq_len, b_ready_cache_len=None): + return max(need_token_num, self._get_need_paged_token_num(b_seq_len, b_ready_cache_len)) + def alloc_token_indices(self, need_size, b_req_idx, b_seq_len, b_ready_cache_len=None) -> torch.Tensor: page_size = get_page_size() if page_size > 1: - return self.alloc_paged_token_indices(b_req_idx, page_size, b_seq_len, b_ready_cache_len) + return self._alloc_paged_token_indices(b_req_idx, page_size, b_seq_len, b_ready_cache_len) else: return self.mem_manager.alloc(need_size) diff --git a/lightllm/server/router/dynamic_prompt/paged_radix_cache.py b/lightllm/server/router/dynamic_prompt/paged_radix_cache.py index 537aba759..71f41d38a 100644 --- a/lightllm/server/router/dynamic_prompt/paged_radix_cache.py +++ b/lightllm/server/router/dynamic_prompt/paged_radix_cache.py @@ -391,31 +391,10 @@ def _print_helper(self, node: TreeNode, indent): self._print_helper(child, indent=indent + 2) return - def free_radix_cache_to_get_enough_token(self, need_token_num=None, b_seq_len=None, b_ready_cache_len=None): + def free_radix_cache_to_get_enough_token(self, need_token_num): assert self.mem_manager is not None - need_pages = 0 - can_use_pages = 0 - if hasattr(self.mem_manager, "can_use_page_size") and self.page_size > 1 and b_seq_len is not None: - - def get_need_page_size(page_size, b_seq_len, b_ready_cache_len=None): - need_new_pages = 0 - if b_ready_cache_len is not None: - need_tokens_array = b_seq_len - b_ready_cache_len - need_pages_array = (need_tokens_array + page_size - 1) // page_size - need_new_pages = need_pages_array.sum() - else: - mask = (b_seq_len - 1) % page_size == 0 - need_new_pages = mask.sum() - 
return need_new_pages - - need_pages = get_need_page_size(self.page_size, b_seq_len, b_ready_cache_len) - can_use_pages = self.mem_manager.can_use_page_size - if need_token_num > self.mem_manager.can_use_mem_size or need_pages > can_use_pages: - need_evict_single_token_num = need_token_num - self.mem_manager.can_use_mem_size - need_evict_page_token_num = (need_pages - can_use_pages) * self.page_size - need_evict_token_num = max(need_evict_single_token_num, need_evict_page_token_num) - remaining_tokens = self.get_tree_total_tokens_num() - self.get_refed_tokens_num() - need_evict_token_num = min(need_evict_token_num, remaining_tokens) + if need_token_num > self.mem_manager.can_use_mem_size: + need_evict_token_num = need_token_num - self.mem_manager.can_use_mem_size release_mems = [] def release_mem(mem_index): @@ -423,7 +402,6 @@ def release_mem(mem_index): return self.evict(need_evict_token_num, release_mem) - if release_mems: - mem_index = torch.concat(release_mems) - self.mem_manager.free(mem_index) + mem_index = torch.concat(release_mems) + self.mem_manager.free(mem_index) return diff --git a/lightllm/server/router/dynamic_prompt/radix_cache.py b/lightllm/server/router/dynamic_prompt/radix_cache.py index f00b6546a..65ec4354b 100644 --- a/lightllm/server/router/dynamic_prompt/radix_cache.py +++ b/lightllm/server/router/dynamic_prompt/radix_cache.py @@ -333,7 +333,7 @@ def _print_helper(self, node: TreeNode, indent): self._print_helper(child, indent=indent + 2) return - def free_radix_cache_to_get_enough_token(self, need_token_num=None, b_seq_len=None, b_ready_cache_len=None): + def free_radix_cache_to_get_enough_token(self, need_token_num): assert self.mem_manager is not None if need_token_num > self.mem_manager.can_use_mem_size: need_evict_token_num = need_token_num - self.mem_manager.can_use_mem_size diff --git a/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py index cdbe8d4e5..5d94996a2 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py @@ -77,9 +77,10 @@ def padded_prepare_prefill_inputs( # dynamic prompt cache 准备 token g_infer_state_lock.acquire() if g_infer_context.radix_cache is not None: - g_infer_context.radix_cache.free_radix_cache_to_get_enough_token( + token_num = g_infer_context.req_manager.calc_real_need_token_num( input_ids.shape[0] - padded_req_num, b_seq_len, b_ready_cache_len ) + g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(token_num) mem_indexes = g_infer_context.req_manager.alloc_token_indices( input_ids.shape[0] - padded_req_num, b_req_idx, b_seq_len, b_ready_cache_len ) @@ -167,7 +168,8 @@ def padded_prepare_decode_inputs( # dynamic prompt cache 准备 token g_infer_state_lock.acquire() if g_infer_context.radix_cache is not None: - g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(b_seq_len.shape[0] - padded_req_num, b_seq_len) + token_num = g_infer_context.req_manager.calc_real_need_token_num(b_seq_len.shape[0] - padded_req_num, b_seq_len) + g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(token_num) mem_indexes = g_infer_context.req_manager.alloc_token_indices( b_seq_len.shape[0] - padded_req_num, b_req_idx, b_seq_len ) diff --git a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py index 
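
# Illustrative sketch of the eviction flow kept above: when the memory manager cannot
# cover need_token_num, the radix tree evicts unreferenced nodes, passes their KV
# indices to a callback, and everything collected is freed in one batched call.
# Simplified; the function name is hypothetical.
import torch

def free_enough_tokens(radix_cache, mem_manager, need_token_num: int):
    if need_token_num <= mem_manager.can_use_mem_size:
        return
    need_evict = need_token_num - mem_manager.can_use_mem_size
    released = []
    radix_cache.evict(need_evict, lambda mem_index: released.append(mem_index))
    mem_manager.free(torch.concat(released))
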
6b44428d4..2ed7cf51d 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py @@ -55,9 +55,10 @@ def prepare_prefill_inputs( # dynamic prompt cache 准备 token g_infer_state_lock.acquire() if g_infer_context.radix_cache is not None: - g_infer_context.radix_cache.free_radix_cache_to_get_enough_token( + token_num = g_infer_context.req_manager.calc_real_need_token_num( input_ids.shape[0], b_seq_len, b_ready_cache_len ) + g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(token_num) mem_indexes = g_infer_context.req_manager.alloc_token_indices( input_ids.shape[0], b_req_idx, b_seq_len, b_ready_cache_len ) @@ -115,7 +116,8 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In # dynamic prompt cache 准备 token g_infer_state_lock.acquire() if g_infer_context.radix_cache is not None: - g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(b_seq_len.shape[0], b_seq_len) + token_num = g_infer_context.req_manager.calc_real_need_token_num(b_seq_len.shape[0], b_seq_len) + g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(token_num) mem_indexes = g_infer_context.req_manager.alloc_token_indices(b_seq_len.shape[0], b_req_idx, b_seq_len) g_infer_state_lock.release() From 64f649f17d2a71ce019576eca5f407fd37f8fbf2 Mon Sep 17 00:00:00 2001 From: niushengxiao Date: Wed, 3 Sep 2025 20:10:18 +0800 Subject: [PATCH 8/9] feat: remove page_size_variable mode --- lightllm/common/basemodel/basemodel.py | 4 +- lightllm/common/basemodel/cuda_graph.py | 4 +- lightllm/common/deepseek2_mem_manager.py | 8 +- .../common/deepseek2_paged_mem_manager.py | 25 ---- lightllm/common/mem_manager.py | 30 +++-- lightllm/common/mem_utils.py | 4 - lightllm/common/paged_mem_manager.py | 69 ---------- lightllm/common/req_manager.py | 85 ++++++------ .../deepseek2/flashattention_infer_struct.py | 27 ++-- .../models/deepseek2/flashinfer_struct.py | 41 ++---- .../layer_infer/transformer_layer_infer.py | 68 +--------- lightllm/models/deepseek2/model.py | 63 +++++---- .../triton_kernel/repack_kv_index.py | 96 +++++--------- .../llama/flashattention_infer_struct.py | 35 ++--- lightllm/models/llama/flashinfer_struct.py | 73 ++++------- .../layer_infer/transformer_layer_infer.py | 123 +----------------- lightllm/server/api_cli.py | 4 +- lightllm/server/api_start.py | 7 - .../model_infer/mode_backend/base_backend.py | 4 +- .../generic_padded_pre_process.py | 4 +- .../mode_backend/generic_pre_process.py | 4 +- lightllm/utils/envs_utils.py | 6 +- .../benchmark/static_inference/model_infer.py | 6 +- .../static_inference/model_infer_mtp.py | 4 +- 24 files changed, 223 insertions(+), 571 deletions(-) delete mode 100755 lightllm/common/deepseek2_paged_mem_manager.py delete mode 100755 lightllm/common/paged_mem_manager.py diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 6aba997c8..5a34b4993 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -687,7 +687,7 @@ def _check_max_len_infer(self): b_seq_len = torch.ones(1, dtype=torch.int32, device="cuda") b_seq_len[:] = self.batch_max_tokens b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda") - mem_indexes = self.req_manager.alloc_token_indices( + mem_indexes = self.req_manager.alloc_mem_indices( len(dummy_input_ids), b_req_idx, b_seq_len, b_ready_cache_len ).cuda() total_token_num = self.batch_max_tokens @@ -764,7 +764,7 @@ def 
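
# Worked example for the token accounting above: with page_size > 1 allocations are
# rounded up to whole pages, so calc_real_need_token_num returns the larger of the
# raw token need and pages_needed * page_size. Numbers below are illustrative.
import torch

b_seq_len = torch.tensor([65, 70, 129])  # decode step, batch of 3
page_size = 64
new_pages = int(((b_seq_len - 1) % page_size == 0).sum())  # 65 and 129 open pages -> 2
assert max(3, new_pages * page_size) == 128                # secure 128 tokens, not 3
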
_autotune_warmup(self): b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda") total_token_num = input_len b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cuda") - mem_indexes = self.req_manager.alloc_token_indices( + mem_indexes = self.req_manager.alloc_mem_indices( len(dummy_input_ids), b_req_idx, b_seq_len, b_ready_cache_len ).cuda() model_input = ModelInput( diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/cuda_graph.py index 4dabad0ff..6a2eac4b0 100644 --- a/lightllm/common/basemodel/cuda_graph.py +++ b/lightllm/common/basemodel/cuda_graph.py @@ -202,7 +202,7 @@ def warmup(self, model): b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda") b_seq_len.fill_(seq_len) b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda") - mem_indexes = model.req_manager.alloc_token_indices(len(input_ids), b_req_idx, b_seq_len).cuda() + mem_indexes = model.req_manager.alloc_mem_indices(len(input_ids), b_req_idx, b_seq_len).cuda() model_input = ModelInput( batch_size=batch_size, @@ -258,7 +258,7 @@ def warmup_overlap(self, model): b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda") b_seq_len.fill_(seq_len) b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda") - mem_indexes = model.req_manager.alloc_token_indices(len(input_ids), b_req_idx, b_seq_len).cuda() + mem_indexes = model.req_manager.alloc_mem_indices(len(input_ids), b_req_idx, b_seq_len).cuda() micro_batch = ModelInput( is_prefill=False, diff --git a/lightllm/common/deepseek2_mem_manager.py b/lightllm/common/deepseek2_mem_manager.py index 6ddec24e2..75a4c3039 100644 --- a/lightllm/common/deepseek2_mem_manager.py +++ b/lightllm/common/deepseek2_mem_manager.py @@ -8,6 +8,7 @@ from lightllm.common.kv_trans_kernel.kv_trans import kv_trans from lightllm.common.kv_trans_kernel.kv_trans_v2 import kv_trans_v2_for_d_node, kv_trans_v2_for_p_node from lightllm.distributed.pynccl import PyNcclCommunicator +from lightllm.utils.envs_utils import get_page_size logger = init_logger(__name__) @@ -20,7 +21,12 @@ def get_cell_size(self): return self.head_num * self.head_dim * self.layer_num * torch._utils._element_size(self.dtype) def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): - self.kv_buffer = torch.empty((layer_num, size + 1, head_num, head_dim), dtype=dtype, device="cuda") + page_size = get_page_size() + self.kv_buffer = torch.empty( + (layer_num, (size // page_size + 1) * page_size, head_num, head_dim), + dtype=dtype, + device="cuda", + ) # todo, etp or edp use the same work buffer here # also it can be used for any kernels for work buffer witout save info only diff --git a/lightllm/common/deepseek2_paged_mem_manager.py b/lightllm/common/deepseek2_paged_mem_manager.py deleted file mode 100755 index 6067a7fb1..000000000 --- a/lightllm/common/deepseek2_paged_mem_manager.py +++ /dev/null @@ -1,25 +0,0 @@ -import torch -import numpy as np -from .deepseek2_mem_manager import Deepseek2MemoryManager -from .paged_mem_manager import PagedMemoryManager -from lightllm.utils.log_utils import init_logger -from lightllm.utils.envs_utils import get_page_size - - -def cdiv(a, b): - return (a + b - 1) // b - - -logger = init_logger(__name__) - - -class Deepseek2PagedMemoryManager(PagedMemoryManager, Deepseek2MemoryManager): - def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): - super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction) - - def 
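
# Sketch of the buffer-shape rule used above (deepseek2 here; the base manager applies
# the same pattern): one extra slot stays reserved for the HOLD/padding token and the
# row count is rounded up to a whole number of pages, so every page maps to a full
# block of rows. Shape arithmetic only; the helper is hypothetical.
def kv_buffer_rows(size: int, page_size: int) -> int:
    # (size // page_size + 1) * page_size >= size + 1 and is a multiple of page_size
    return (size // page_size + 1) * page_size

assert kv_buffer_rows(1000, 1) == 1001    # page_size 1 keeps the old size + 1 layout
assert kv_buffer_rows(1000, 64) == 1024   # rounded up, still leaves room for HOLD
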
_init_buffers(self, size, dtype, head_num, head_dim, layer_num): - self.kv_buffer = torch.empty( - (layer_num, cdiv(size, get_page_size()) * get_page_size(), head_num, head_dim), - dtype=dtype, - device="cuda", - ) diff --git a/lightllm/common/mem_manager.py b/lightllm/common/mem_manager.py index 4142ce4aa..26ec7c970 100755 --- a/lightllm/common/mem_manager.py +++ b/lightllm/common/mem_manager.py @@ -2,6 +2,7 @@ import os import torch import torch.distributed as dist +import triton from typing import List, Union from lightllm.server.pd_io_struct import KVMoveTask from lightllm.utils.log_utils import init_logger @@ -9,7 +10,7 @@ from lightllm.utils.profile_max_tokens import get_available_gpu_memory, get_total_gpu_memory from lightllm.common.kv_trans_kernel.kv_trans import kv_trans from lightllm.utils.dist_utils import get_current_rank_in_node -from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args +from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args, get_page_size from lightllm.distributed.pynccl import PyNcclCommunicator from lightllm.utils.dist_utils import get_current_device_id @@ -81,7 +82,12 @@ def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): # 分配,内部实际也没有管理,这个token是预留来对一些特殊的运行模式,如多dp下,overlap microbatch # 等模式下 padding 一些请求,使推理过程可以正常运行采用的,其索引值为size,存储在HOLD_TOKEN_MEMINDEX # 成员变量中,其与 req_manager 中的HOLD_REQUEST_ID具有类似的作用和意义。 - self.kv_buffer = torch.empty((layer_num, size + 1, 2 * head_num, head_dim), dtype=dtype, device="cuda") + page_size = get_page_size() + self.kv_buffer = torch.empty( + (layer_num, (size // page_size + 1) * page_size, 2 * head_num, head_dim), + dtype=dtype, + device="cuda", + ) def alloc_kv_move_buffer(self, max_req_total_len): """ @@ -244,6 +250,7 @@ def _free_buffers(self): self.kv_buffer = None def alloc(self, need_size) -> torch.Tensor: + assert need_size % get_page_size() == 0, "Need size must be a multiple of page size" if need_size > self.mark_end - self.mark_start: logger.error(f"warn no enough cache need_size {need_size} left_size {self.can_use_mem_size}") assert False, "error alloc state" @@ -265,18 +272,25 @@ def free(self, free_index: Union[torch.Tensor, List[int]]): """ end = self.mark_start - start = self.mark_start - len(free_index) - assert start >= 0, f"error free state start: {self.mark_start} free len {len(free_index)}" + page_size = get_page_size() + free_len = page_size * triton.cdiv(len(free_index), page_size) + start = self.mark_start - free_len + assert start >= 0, f"error free state start: {self.mark_start} free len {free_len}" if isinstance(free_index, list): - self.mem_state.numpy()[start:end] = free_index + free_index = torch.tensor(free_index) + + # 从 gpu 到 cpu 的拷贝操作是流内阻塞操作 + if page_size > 1: + base_free_index = free_index[free_index % page_size == 0] + token_idxs = base_free_index[:, None] + torch.arange(page_size) + self.mem_state[start:end] = token_idxs.flatten() else: - # 从 gpu 到 cpu 的拷贝操作是流内阻塞操作 self.mem_state[start:end] = free_index - self.mark_start -= len(free_index) + self.mark_start -= free_len - self.can_use_mem_size += len(free_index) + self.can_use_mem_size += free_len self.shared_can_use_token_num.set_value(self.can_use_mem_size) if self.can_use_mem_size == len(self.mem_state): diff --git a/lightllm/common/mem_utils.py b/lightllm/common/mem_utils.py index aab7adc08..dfb8e849d 100644 --- a/lightllm/common/mem_utils.py +++ b/lightllm/common/mem_utils.py @@ -4,7 +4,6 @@ from lightllm.common.export_calibration_mem_manager import 
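
# Illustrative restatement of the patched free() above: with page_size > 1 only the
# page-base indices (idx % page_size == 0) are kept from free_index, each base is
# expanded back into its page_size consecutive token ids, and the pool pointer moves
# by the page-aligned length. Simplified, CPU-only, without the shared-counter updates.
import torch

def expand_pages_for_free(free_index: torch.Tensor, page_size: int) -> torch.Tensor:
    if page_size == 1:
        return free_index
    base = free_index[free_index % page_size == 0]               # one entry per page
    return (base[:, None] + torch.arange(page_size)).flatten()   # whole pages go back

# e.g. page_size=4, free_index=[8, 9, 10, 11, 20, 21] -> bases [8, 20]
# -> [8, 9, 10, 11, 20, 21, 22, 23]; free_len rounds 6 up to 8, a multiple of 4.
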
ExportCalibrationMemoryManager from lightllm.common.ppl_int8kv_mem_manager import PPLINT8KVMemoryManager from lightllm.common.ppl_int4kv_mem_manager import PPLINT4KVMemoryManager -from lightllm.common.paged_mem_manager import PagedMemoryManager from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) @@ -29,9 +28,6 @@ def select_mem_manager_class(mode): elif "export_fp8kv_calibration" in mode: memory_manager_class = ExportCalibrationMemoryManager logger.info("Using mode export fp8kv calibration") - elif "page_size_variable" in mode: - memory_manager_class = PagedMemoryManager - logger.info("Page size will be variable") else: memory_manager_class = MemoryManager logger.info("Model kv cache using mode normal") diff --git a/lightllm/common/paged_mem_manager.py b/lightllm/common/paged_mem_manager.py deleted file mode 100755 index f6bb25b5d..000000000 --- a/lightllm/common/paged_mem_manager.py +++ /dev/null @@ -1,69 +0,0 @@ -import torch -import numpy as np -from .mem_manager import MemoryManager -from typing import List, Union -from lightllm.utils.log_utils import init_logger -from lightllm.utils.envs_utils import get_unique_server_name, get_page_size -from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt -from lightllm.utils.dist_utils import get_current_rank_in_node - - -def cdiv(a, b): - return (a + b - 1) // b - - -logger = init_logger(__name__) - - -class PagedMemoryManager(MemoryManager): - def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): - super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction) - - def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): - self.kv_buffer = torch.empty( - (layer_num, cdiv(size, get_page_size()) * get_page_size(), 2 * head_num, head_dim), - dtype=dtype, - device="cuda", - ) - - # 要求长度必须是page_size的整数倍,page内token索引必须连续 - def check_cache_page_valid(self, values: torch.Tensor): - end = len(values) - assert end % self.page_size == 0, "Values length must be a multiple of page size" - total_pages = end // self.page_size - for page_idx in range(total_pages): - values_start = page_idx * self.page_size - values_end = min((page_idx + 1) * self.page_size, end) - page_token_idxs = values[values_start:values_end] - if len(page_token_idxs) > 1: - expected_idxs = torch.arange( - page_token_idxs[0], - page_token_idxs[0] + len(page_token_idxs), - dtype=page_token_idxs.dtype, - device=page_token_idxs.device, - ) - if not torch.equal(page_token_idxs, expected_idxs): - return False - return True - - def alloc(self, need_size) -> torch.Tensor: - assert need_size % get_page_size() == 0, "Need size must be a multiple of page size" - return super().alloc(need_size) - - def free(self, free_index: Union[torch.Tensor, List[int]]): - page_size = get_page_size() - if page_size == 1: - return super().free(free_index) - - if isinstance(free_index, list): - free_index = torch.tensor(free_index) - base_free_index = free_index[free_index % page_size == 0] - if len(base_free_index) == 0: - return - token_idxs = base_free_index[:, None] + torch.arange(page_size, device=free_index.device) - token_idxs = token_idxs.flatten() - super().free(token_idxs) - return - - def free_all(self): - super().free_all() diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py index 7797f280c..d65968025 100644 --- a/lightllm/common/req_manager.py +++ b/lightllm/common/req_manager.py @@ -1,5 +1,6 @@ import torch import collections +import triton from 
lightllm.utils.log_utils import init_logger from .mem_manager import MemoryManager from typing import List, Optional @@ -11,10 +12,6 @@ logger = init_logger(__name__) -def cdiv(a, b): - return (a + b - 1) // b - - class _ReqNode: def __init__(self, index): self.index = index @@ -71,25 +68,60 @@ def __init__(self, max_request_num, max_sequence_length, mem_manager: MemoryMana self.max_request_num = max_request_num self.HOLD_REQUEST_ID = max_request_num + def calc_real_need_token_num(self, need_token_num, b_seq_len, b_ready_cache_len=None): + return max(need_token_num, self._get_need_paged_token_num(b_seq_len, b_ready_cache_len)) + + def alloc_mem_indices(self, need_size, b_req_idx=None, b_seq_len=None, b_ready_cache_len=None) -> torch.Tensor: + page_size = get_page_size() + if page_size > 1 and b_req_idx is not None and b_seq_len is not None: + return self._alloc_paged_mem_indices(b_req_idx, page_size, b_seq_len, b_ready_cache_len) + else: + return self.mem_manager.alloc(need_size) + + def alloc(self): + return self.req_list.alloc() + + def free(self, free_req_indexes: List[int], free_token_index): + for req_index in free_req_indexes: + self.req_list.free(req_index) + + if self.req_list.is_all_free(): + logger.debug(f"freed all request size {self.req_list.can_alloc_size}") + self.mem_manager.free(free_token_index) + + def free_req(self, free_req_index: int): + self.req_list.free(free_req_index) + if self.req_list.is_all_free(): + logger.debug(f"freed all request size {self.req_list.can_alloc_size}") + return + + def free_token(self, free_token_index): + self.mem_manager.free(free_token_index) + return + + def free_all(self): + self.req_list = _ReqLinkedList(self.max_request_num) + return + def _expand_by_page_size(self, b_token_len, page_size): - # 将seq_len按page整数倍展开,例如seq_len = [9,9,9] -> page_len = [4,4,1,4,4,1,4,4,1], page_size = 4 - b_page_len = cdiv(b_token_len, page_size) + # 将seq_len按page整数倍展开,例如seq_len = [9,9,9] -> p_token_len = [4,4,1,4,4,1,4,4,1], page_size = 4 + b_page_len = triton.cdiv(b_token_len, page_size) need_pages_num = b_page_len.sum() p_token_len = torch.full((need_pages_num,), page_size, dtype=b_token_len.dtype, device=b_token_len.device) cumsum_pages = torch.cumsum(b_page_len, dim=0) last_page_positions = cumsum_pages - 1 remainders = b_token_len - (b_page_len - 1) * page_size p_token_len[last_page_positions] = remainders - return need_pages_num, b_page_len, p_token_len + return need_pages_num, p_token_len - def _alloc_paged_token_indices(self, b_req_idx, page_size, b_seq_len, b_ready_cache_len): + def _alloc_paged_mem_indices(self, b_req_idx, page_size, b_seq_len, b_ready_cache_len): if b_ready_cache_len is not None: # prefill b_seq_len = b_seq_len.cpu() b_ready_cache_len = b_ready_cache_len.cpu() b_token_len = b_seq_len - b_ready_cache_len - total_pages_needed, b_page_len, p_token_len = self._expand_by_page_size(b_token_len, page_size) + total_pages_needed, p_token_len = self._expand_by_page_size(b_token_len, page_size) paged_token_idxs = self.mem_manager.alloc(total_pages_needed * page_size) pages = paged_token_idxs.view(-1, page_size) mask = torch.arange(page_size, device=p_token_len.device) < p_token_len.unsqueeze(1) @@ -126,41 +158,6 @@ def _get_need_paged_token_num(self, b_seq_len, b_ready_cache_len=None): need_new_pages = mask.sum() return need_new_pages * page_size - def calc_real_need_token_num(self, need_token_num, b_seq_len, b_ready_cache_len=None): - return max(need_token_num, self._get_need_paged_token_num(b_seq_len, b_ready_cache_len)) - - def 
alloc_token_indices(self, need_size, b_req_idx, b_seq_len, b_ready_cache_len=None) -> torch.Tensor: - page_size = get_page_size() - if page_size > 1: - return self._alloc_paged_token_indices(b_req_idx, page_size, b_seq_len, b_ready_cache_len) - else: - return self.mem_manager.alloc(need_size) - - def alloc(self): - return self.req_list.alloc() - - def free(self, free_req_indexes: List[int], free_token_index): - for req_index in free_req_indexes: - self.req_list.free(req_index) - - if self.req_list.is_all_free(): - logger.debug(f"freed all request size {self.req_list.can_alloc_size}") - self.mem_manager.free(free_token_index) - - def free_req(self, free_req_index: int): - self.req_list.free(free_req_index) - if self.req_list.is_all_free(): - logger.debug(f"freed all request size {self.req_list.can_alloc_size}") - return - - def free_token(self, free_token_index): - self.mem_manager.free(free_token_index) - return - - def free_all(self): - self.req_list = _ReqLinkedList(self.max_request_num) - return - class ReqSamplingParamsManager: """ diff --git a/lightllm/models/deepseek2/flashattention_infer_struct.py b/lightllm/models/deepseek2/flashattention_infer_struct.py index 5355990f1..be979e288 100644 --- a/lightllm/models/deepseek2/flashattention_infer_struct.py +++ b/lightllm/models/deepseek2/flashattention_infer_struct.py @@ -2,15 +2,12 @@ import torch import numpy as np import torch.distributed as dist +import triton from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo from lightllm.utils.dist_utils import get_current_device_id from lightllm.utils.envs_utils import get_page_size -def cdiv(a, b): - return (a + b - 1) // b - - class Deepseek2FlashAttentionStateInfo(Deepseek2InferStateInfo): _shared_page_table_buffer = None @@ -45,25 +42,19 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): self.cu_seqlens_k = self.b1_cu_kv_seq_len max_seq_len_k = self.max_kv_seq_len if self.batch_size <= model.graph_max_batch_size and self.max_len_in_batch <= model.graph_max_len_in_batch: - length = cdiv(model.graph_max_len_in_batch, self.page_size) + length = triton.cdiv(model.graph_max_len_in_batch, self.page_size) page_buffer = Deepseek2FlashAttentionStateInfo.get_page_table_buffer(model.graph_max_batch_size, length) self.page_table = page_buffer[self.microbatch_index][: self.batch_size * length].reshape( self.batch_size, length ) else: - length = cdiv(self.max_len_in_batch, self.page_size) + length = triton.cdiv(self.max_len_in_batch, self.page_size) self.page_table = torch.empty((self.batch_size, length), dtype=torch.int32).to(input_ids.device) - if "page_size_variable" in model.mode: - length = cdiv(max_seq_len_k, self.page_size) - token_indexs = model.req_manager.req_to_token_indexs[ - self.b_req_idx, : length * self.page_size : self.page_size - ] - self.page_table[:, :length].copy_(token_indexs // self.page_size) - self.page_table[:, length:].fill_(0) - else: - self.page_table[:, :max_seq_len_k].copy_( - model.req_manager.req_to_token_indexs[self.b_req_idx, :max_seq_len_k] - ) - self.page_table[:, max_seq_len_k:].fill_(0) + length = triton.cdiv(max_seq_len_k, self.page_size) + token_indexs = model.req_manager.req_to_token_indexs[ + self.b_req_idx, : length * self.page_size : self.page_size + ] + self.page_table[:, :length].copy_(token_indexs // self.page_size) + self.page_table[:, length:].fill_(0) return diff --git a/lightllm/models/deepseek2/flashinfer_struct.py b/lightllm/models/deepseek2/flashinfer_struct.py index 1cb007377..213b89ad7 100644 --- 
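
# Sketch of how the page_table above is derived from the per-token index table:
# tokens inside a page are consecutive, so sampling req_to_token_indexs with a stride
# of page_size and integer-dividing by page_size yields one page id per page.
# Simplified restatement with names mirroring the hunk; the helper is hypothetical.
import torch

def build_page_table(req_to_token_indexs, b_req_idx, max_seq_len_k, page_size):
    length = (max_seq_len_k + page_size - 1) // page_size
    token_indexs = req_to_token_indexs[b_req_idx, : length * page_size : page_size]
    return (token_indexs // page_size).to(torch.int32)
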
a/lightllm/models/deepseek2/flashinfer_struct.py +++ b/lightllm/models/deepseek2/flashinfer_struct.py @@ -2,13 +2,10 @@ import torch import numpy as np import torch.distributed as dist +import triton from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo from lightllm.utils.envs_utils import get_env_start_args, get_page_size -from lightllm.models.deepseek2.triton_kernel.repack_kv_index import repack_kv_index, repack_paged_kv_index_from_tokens - - -def cdiv(a, b): - return (a + b - 1) // b +from lightllm.models.deepseek2.triton_kernel.repack_kv_index import repack_kv_index class Deepseek2FlashInferStateInfo(Deepseek2InferStateInfo): @@ -28,7 +25,7 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): if not self.is_prefill: if get_env_start_args().enable_flashinfer_decode: self.q_indptr = torch.arange(self.batch_size + 1, dtype=torch.int32).to(input_ids.device) - length = cdiv(self.flashinfer_extra_state.max_seq_length, self.page_size) + length = triton.cdiv(self.flashinfer_extra_state.max_seq_length, self.page_size) if self.batch_size <= model.graph_max_batch_size: self.kv_indices = self.flashinfer_extra_state.kv_indices_buffer[self.microbatch_index][ : self.batch_size * length @@ -39,27 +36,17 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): dtype=torch.int32, device=input_ids.device, ) - if "page_size_variable" in model.mode: - b_page_len = cdiv(self.b_seq_len, self.page_size) - self.kv_starts[1:] = b_page_len.cumsum(0) - repack_paged_kv_index_from_tokens( - self.req_manager.req_to_token_indexs, - self.b_req_idx, - b_page_len, - self.kv_starts[:-1], - cdiv(self.max_len_in_batch, self.page_size), - self.page_size, - self.kv_indices, - ) - else: - repack_kv_index( - self.req_manager.req_to_token_indexs, - self.b_req_idx, - self.b_seq_len, - self.b_start_loc, - self.max_len_in_batch, - self.kv_indices, - ) + b_page_len = triton.cdiv(self.b_seq_len, self.page_size) + self.kv_starts[1:] = b_page_len.cumsum(0) + repack_kv_index( + self.req_manager.req_to_token_indexs, + self.b_req_idx, + b_page_len, + self.kv_starts[:-1], + triton.cdiv(self.max_len_in_batch, self.page_size), + self.page_size, + self.kv_indices, + ) if self.decode_wrapper is None: self.decode_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper( self.flashinfer_extra_state.workspace_buffer, diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py index ea2a80336..5ca0d3215 100644 --- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py @@ -93,18 +93,6 @@ def _bind_attention(self): self._token_attention_kernel = partial( Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashdecoding_fp8, self ) - elif "page_size_variable" in self.mode: - self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_normal, self) - if get_env_start_args().enable_fa3: - self._token_attention_kernel = partial( - Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashattention_paged, self - ) - elif get_env_start_args().enable_flashinfer_decode: - self._token_attention_kernel = partial( - Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashinfer_paged, self - ) - else: - raise Exception("Page size variable mode is not supported in other backends.") else: self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_normal, self) if 
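
# Illustrative view of the buffer repack_kv_index fills for flashinfer above: each
# request's page ids are written contiguously starting at kv_starts[i], producing one
# flat, batch-concatenated page-id array. Pure-Python loop form of the same layout,
# for clarity only (the patch does this with a triton kernel).
def repack_kv_ref_simple(req_to_token_indexs, b_req_idx, b_page_len, kv_starts,
                         page_size, out_kv_index):
    for req, n_pages, start in zip(b_req_idx.tolist(), b_page_len.tolist(),
                                   kv_starts.tolist()):
        tokens = req_to_token_indexs[req, : n_pages * page_size : page_size]
        out_kv_index[start : start + n_pages] = tokens // page_size
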
get_env_start_args().enable_fa3: @@ -559,35 +547,6 @@ def _context_attention_kernel_origin_fp8( def _token_gqa_decode_attention_flashattention( self, q, infer_state: Deepseek2FlashAttentionStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None - ): - q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] - q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) - kv = infer_state.mem_manager.kv_buffer[self.layer_num_] - k_rope = kv[:, :, -self.qk_rope_head_dim :].reshape(-1, 1, 1, self.qk_rope_head_dim) - kv_nope = kv[:, :, : -self.qk_rope_head_dim].reshape(-1, 1, 1, self.kv_lora_rank) - k_descale, v_descale = None, None - o_tensor = flash_attn_with_kvcache( - q=q_rope, - k_cache=k_rope, - v_cache=kv_nope, - qv=q_nope, - page_table=infer_state.page_table, - cache_seqlens=infer_state.b_seq_len, - cu_seqlens_q=infer_state.cu_seqlens_q, - cu_seqlens_k_new=infer_state.cu_seqlens_k, - max_seqlen_q=1, - softmax_scale=self.softmax_scale, - causal=True, - window_size=(-1, -1), - softcap=0.0, - k_descale=k_descale, - v_descale=v_descale, - return_softmax_lse=False, - ) - return o_tensor - - def _token_gqa_decode_attention_flashattention_paged( - self, q, infer_state: Deepseek2FlashInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None ): q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) @@ -624,30 +583,15 @@ def _token_gqa_decode_attention_flashinfer( kv = infer_state.mem_manager.kv_buffer[self.layer_num_] o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) + k_nope = kv[:, :, : -self.qk_rope_head_dim] + k_rope = kv[:, :, -self.qk_rope_head_dim :] infer_state.decode_wrapper.run( q_nope, q_rope, - kv[:, :, : -self.qk_rope_head_dim], - kv[:, :, -self.qk_rope_head_dim :], - out=o_tensor, - return_lse=False, - ) - return o_tensor - - def _token_gqa_decode_attention_flashinfer_paged( - self, q, infer_state: Deepseek2FlashInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None - ): - q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] - q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) - - kv = infer_state.mem_manager.kv_buffer[self.layer_num_] - o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) - - infer_state.decode_wrapper.run( - q_nope, - q_rope, - kv[:, :, : -self.qk_rope_head_dim].reshape(-1, infer_state.page_size, 1, self.kv_lora_rank), - kv[:, :, -self.qk_rope_head_dim :].reshape(-1, infer_state.page_size, 1, self.qk_rope_head_dim), + k_nope if infer_state.page_size == 1 else k_nope.reshape(-1, infer_state.page_size, 1, self.kv_lora_rank), + k_rope + if infer_state.page_size == 1 + else k_rope.reshape(-1, infer_state.page_size, 1, self.qk_rope_head_dim), out=o_tensor, return_lse=False, ) diff --git a/lightllm/models/deepseek2/model.py b/lightllm/models/deepseek2/model.py index 50771e015..2131539fa 100644 --- a/lightllm/models/deepseek2/model.py +++ b/lightllm/models/deepseek2/model.py @@ -10,7 +10,6 @@ from lightllm.models.llama.model import LlamaTpPartModel from lightllm.common.deepseek2_mem_manager import Deepseek2MemoryManager -from lightllm.common.deepseek2_paged_mem_manager import Deepseek2PagedMemoryManager from lightllm.common.deepseek2_fp8kv_mem_manager import Deepseek2FP8KVMemoryManager from lightllm.utils.log_utils import init_logger from lightllm.models.llama.yarn_rotary_utils import 
get_deepseek_mscale @@ -22,35 +21,6 @@ logger = init_logger(__name__) -class DeepSeek2FlashInferStateExtraInfo: - def __init__(self, model): - num_heads = model.config["num_attention_heads"] - self.tp_q_head_num = num_heads // get_dp_world_size() - self.qk_nope_head_dim = model.qk_nope_head_dim - self.qk_rope_head_dim = model.qk_rope_head_dim - self.kv_lora_rank = model.kv_lora_rank - self.q_data_type = model.data_type - self.kv_data_type = model.data_type - self.workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8, device=get_current_device_id()) - self.max_seq_length = model.max_seq_length - self.softmax_scale = (self.qk_nope_head_dim + self.qk_rope_head_dim) ** (-0.5) - self.kv_indices_buffer = [ - torch.empty( - model.graph_max_batch_size * self.max_seq_length, dtype=torch.int32, device=get_current_device_id() - ), - torch.empty( - model.graph_max_batch_size * self.max_seq_length, dtype=torch.int32, device=get_current_device_id() - ), - ] - if model.config["rope_scaling"] is not None: - rope_scaling = model.config["rope_scaling"] - mscale_all_dim = rope_scaling.get("mscale_all_dim", 0) - scaling_factor = rope_scaling["factor"] - if mscale_all_dim: - mscale = get_deepseek_mscale(scaling_factor, mscale_all_dim) - self.softmax_scale = self.softmax_scale * mscale * mscale - - @ModelRegistry(["deepseek_v2", "deepseek_v3"]) class Deepseek2TpPartModel(LlamaTpPartModel): # weight class @@ -98,10 +68,6 @@ def _init_mem_manager(self): manager_class = Deepseek2MemoryManager if "triton_fp8kv" in self.mode: manager_class = Deepseek2FP8KVMemoryManager - elif "page_size_variable" in self.mode: - manager_class = Deepseek2PagedMemoryManager - elif self.mode: - raise ValueError(f"Unsupported mode for deepseek2: {self.mode}") # mtp 模式下需要在mem manger上扩展draft model使用的layer added_mtp_layer_num = 0 @@ -202,3 +168,32 @@ def _init_to_get_yarn_rotary(self): def _context_forward(self, input_ids, infer_state): predict_logics = super()._context_forward(input_ids, infer_state) return predict_logics + + +class DeepSeek2FlashInferStateExtraInfo: + def __init__(self, model): + num_heads = model.config["num_attention_heads"] + self.tp_q_head_num = num_heads // get_dp_world_size() + self.qk_nope_head_dim = model.qk_nope_head_dim + self.qk_rope_head_dim = model.qk_rope_head_dim + self.kv_lora_rank = model.kv_lora_rank + self.q_data_type = model.data_type + self.kv_data_type = model.data_type + self.workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8, device=get_current_device_id()) + self.max_seq_length = model.max_seq_length + self.softmax_scale = (self.qk_nope_head_dim + self.qk_rope_head_dim) ** (-0.5) + self.kv_indices_buffer = [ + torch.empty( + model.graph_max_batch_size * self.max_seq_length, dtype=torch.int32, device=get_current_device_id() + ), + torch.empty( + model.graph_max_batch_size * self.max_seq_length, dtype=torch.int32, device=get_current_device_id() + ), + ] + if model.config["rope_scaling"] is not None: + rope_scaling = model.config["rope_scaling"] + mscale_all_dim = rope_scaling.get("mscale_all_dim", 0) + scaling_factor = rope_scaling["factor"] + if mscale_all_dim: + mscale = get_deepseek_mscale(scaling_factor, mscale_all_dim) + self.softmax_scale = self.softmax_scale * mscale * mscale diff --git a/lightllm/models/deepseek2/triton_kernel/repack_kv_index.py b/lightllm/models/deepseek2/triton_kernel/repack_kv_index.py index 4cb7012b2..65a165656 100644 --- a/lightllm/models/deepseek2/triton_kernel/repack_kv_index.py +++ 
b/lightllm/models/deepseek2/triton_kernel/repack_kv_index.py @@ -70,7 +70,7 @@ def _fwd_kernel_repack_page_kv_index_from_tokens( @torch.no_grad() -def repack_kv_index(kv_index, req_index, seq_len, start_loc, max_seq_len, out_kv_index): +def repack_kv_index(req_to_token_indexs, req_index, seq_len, start_loc, max_seq_len, page_size, out_kv_index): batch_size = req_index.shape[0] # flashinfer requires out_kv_index to be zeroed before use out_kv_index.zero_() @@ -80,51 +80,35 @@ def repack_kv_index(kv_index, req_index, seq_len, start_loc, max_seq_len, out_kv triton.cdiv(max_seq_len, BLOCK), ) - _fwd_kernel_repack_kv_index[grid]( - kv_index, - req_index, - out_kv_index, - seq_len, - start_loc, - kv_index.stride(0), - SEQ_BLOCK=BLOCK, - num_warps=8, - num_stages=1, - ) + if page_size > 1: + _fwd_kernel_repack_page_kv_index_from_tokens[grid]( + req_to_token_indexs, + req_index, + out_kv_index, + seq_len, + start_loc, + page_size, + req_to_token_indexs.stride(0), + SEQ_BLOCK=BLOCK, + num_warps=8, + num_stages=1, + ) + else: + _fwd_kernel_repack_kv_index[grid]( + req_to_token_indexs, + req_index, + out_kv_index, + seq_len, + start_loc, + req_to_token_indexs.stride(0), + SEQ_BLOCK=BLOCK, + num_warps=8, + num_stages=1, + ) return -@torch.no_grad() -def repack_paged_kv_index_from_tokens( - req_to_token_indexs, req_index, seq_len, start_loc, max_seq_len, page_size, out_kv_index -): - batch_size = req_index.shape[0] - out_kv_index.zero_() - - BLOCK = 64 - grid = ( - batch_size, - triton.cdiv(max_seq_len, BLOCK), - ) - - _fwd_kernel_repack_page_kv_index_from_tokens[grid]( - req_to_token_indexs, - req_index, - out_kv_index, - seq_len, - start_loc, - page_size, - req_to_token_indexs.stride(0), - SEQ_BLOCK=BLOCK, - num_warps=8, - num_stages=1, - ) - return - - -def ref_repack_page_kv_index_with_token_input( - req_to_token_indexs, req_index, seq_len, start_loc, max_seq_len, page_size, out_kv_index -): +def repack_kv_ref(req_to_token_indexs, req_index, seq_len, start_loc, max_seq_len, page_size, out_kv_index): page_indexs = torch.zeros_like(req_to_token_indexs) valid_mask = req_to_token_indexs % page_size == 0 batch_size, seq_len_dim = req_to_token_indexs.shape @@ -136,12 +120,9 @@ def ref_repack_page_kv_index_with_token_input( (torch.where(valid_mask, req_to_token_indexs // page_size, 0) * valid_mask.int()).flatten(), ) - repack_kv_index(page_indexs, req_index, seq_len, start_loc, max_seq_len, out_kv_index) - - -def repack_kv_ref(req_to_token_indexs, b_req_idx, b_seq_len, b_start_loc, output): - for b, sl, start in zip(b_req_idx, b_seq_len, b_start_loc): - output[start : start + sl] = req_to_token_indexs[b][:sl] + for b, sl, start in zip(req_index, seq_len, start_loc): + out_kv_index[start : start + sl] = page_indexs[b][:sl] + return if __name__ == "__main__": @@ -172,25 +153,18 @@ def repack_kv_ref(req_to_token_indexs, b_req_idx, b_seq_len, b_start_loc, output output = torch.zeros((b_seq_len.sum(),)).cuda().int() ref = torch.zeros((b_seq_len.sum(),)).cuda().int() - fn1 = lambda: repack_kv_ref(req_to_token_indexs, b_req_idx, b_seq_len, b_start_loc, ref) - fn2 = lambda: repack_kv_index(req_to_token_indexs, b_req_idx, b_seq_len, b_start_loc, MAX_SEQ_LEN, output) - ms1 = triton.testing.do_bench(fn1) - ms2 = triton.testing.do_bench_cudagraph(fn2) - print(f"repack_kv_index: ref={ms1:.3f}ms, triton={ms2:.3f}ms") - assert torch.allclose(output.float(), ref.float()) - b_page_len = triton.cdiv(b_seq_len, PAGE_SIZE) page_output = torch.zeros((b_page_len.sum(),)).cuda().int() page_ref = 
torch.zeros((b_page_len.sum(),)).cuda().int() b_start_loc[1:] = b_page_len.cumsum(0)[:-1] max_seq_len = triton.cdiv(MAX_SEQ_LEN, PAGE_SIZE) - fn3 = lambda: ref_repack_page_kv_index_with_token_input( + fn1 = lambda: repack_kv_ref( req_to_token_indexs, b_req_idx, b_page_len, b_start_loc, max_seq_len, PAGE_SIZE, page_ref ) - fn4 = lambda: repack_paged_kv_index_from_tokens( + fn2 = lambda: repack_kv_index( req_to_token_indexs, b_req_idx, b_page_len, b_start_loc, max_seq_len, PAGE_SIZE, page_output ) - ms3 = triton.testing.do_bench(fn3) - ms4 = triton.testing.do_bench_cudagraph(fn4) - print(f"repack_paged_kv_index_from_tokens: ref={ms3:.3f}ms, triton={ms4:.3f}ms") + ms1 = triton.testing.do_bench(fn1) + ms2 = triton.testing.do_bench_cudagraph(fn2) + print(f"repack_kv_index: ref={ms1:.3f}ms, triton={ms2:.3f}ms") assert torch.allclose(page_output.float(), page_ref.float()) diff --git a/lightllm/models/llama/flashattention_infer_struct.py b/lightllm/models/llama/flashattention_infer_struct.py index 51c302b62..1c62052e3 100644 --- a/lightllm/models/llama/flashattention_infer_struct.py +++ b/lightllm/models/llama/flashattention_infer_struct.py @@ -2,6 +2,7 @@ import torch import numpy as np import torch.distributed as dist +import triton from lightllm.models.llama.infer_struct import LlamaInferStateInfo from lightllm.utils.envs_utils import get_env_start_args, get_page_size from lightllm.utils.dist_utils import get_current_device_id @@ -9,10 +10,6 @@ from lightllm.common.basemodel.batch_objs import ModelInput -def cdiv(a, b): - return (a + b - 1) // b - - class FlashAttentionStateInfo(LlamaInferStateInfo): _shared_page_table_buffer = None @@ -33,38 +30,32 @@ def _init_flash_attention_state(self, model, input_ids: torch.Tensor): if self.is_prefill: self.cu_seqlens_q = self.b1_cu_q_seq_len.int() self.cu_seqlens_k = self.b1_cu_kv_seq_len.int() - length = cdiv(self.max_seq_len, self.page_size) + length = triton.cdiv(self.max_seq_len, self.page_size) self.page_table = torch.empty((self.batch_size, length), dtype=torch.int32, device=input_ids.device) - if "page_size_variable" in model.mode: - token_indexs = model.req_manager.req_to_token_indexs[ - self.b_req_idx, : length * self.page_size : self.page_size - ] - self.page_table.copy_(token_indexs // self.page_size) - else: - self.page_table.copy_(model.req_manager.req_to_token_indexs[self.b_req_idx, :length]) + token_indexs = model.req_manager.req_to_token_indexs[ + self.b_req_idx, : length * self.page_size : self.page_size + ] + self.page_table.copy_(token_indexs // self.page_size) else: # Meta information of flashattention for decoding self.cu_seqlens_q = self.b1_cu_q_seq_len.int() self.cu_seqlens_k = self.b1_cu_kv_seq_len.int() max_seq_len_k = self.max_kv_seq_len if self.batch_size <= model.graph_max_batch_size and self.max_len_in_batch <= model.graph_max_len_in_batch: - length = cdiv(model.graph_max_len_in_batch, self.page_size) + length = triton.cdiv(model.graph_max_len_in_batch, self.page_size) page_buffer = FlashAttentionStateInfo.get_page_table_buffer(model.graph_max_batch_size, length) self.page_table = page_buffer[self.microbatch_index][: self.batch_size * length].reshape( self.batch_size, length ) else: - length = cdiv(self.max_len_in_batch, self.page_size) + length = triton.cdiv(self.max_len_in_batch, self.page_size) self.page_table = torch.empty((self.batch_size, length), dtype=torch.int32, device=input_ids.device) - length = cdiv(max_seq_len_k, self.page_size) - if "page_size_variable" in model.mode: - token_indexs = 
model.req_manager.req_to_token_indexs[ - self.b_req_idx, : length * self.page_size : self.page_size - ] - self.page_table[:, :length].copy_(token_indexs // self.page_size) - else: - self.page_table[:, :length].copy_(model.req_manager.req_to_token_indexs[self.b_req_idx, :length]) + length = triton.cdiv(max_seq_len_k, self.page_size) + token_indexs = model.req_manager.req_to_token_indexs[ + self.b_req_idx, : length * self.page_size : self.page_size + ] + self.page_table[:, :length].copy_(token_indexs // self.page_size) self.page_table[:, length:].fill_(0) if "offline_calibration_fp8kv" in model.mode: diff --git a/lightllm/models/llama/flashinfer_struct.py b/lightllm/models/llama/flashinfer_struct.py index 4c696186f..1d655d9ff 100644 --- a/lightllm/models/llama/flashinfer_struct.py +++ b/lightllm/models/llama/flashinfer_struct.py @@ -2,13 +2,10 @@ import torch import numpy as np import torch.distributed as dist +import triton from lightllm.models.llama.infer_struct import LlamaInferStateInfo from lightllm.utils.envs_utils import get_env_start_args, get_page_size -from lightllm.models.deepseek2.triton_kernel.repack_kv_index import repack_kv_index, repack_paged_kv_index_from_tokens - - -def cdiv(a, b): - return (a + b - 1) // b +from lightllm.models.deepseek2.triton_kernel.repack_kv_index import repack_kv_index class LlamaFlashInferStateInfo(LlamaInferStateInfo): @@ -28,7 +25,7 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): if not self.is_prefill: if get_env_start_args().enable_flashinfer_decode: self.kv_last_page_len = torch.full((self.batch_size,), 1, dtype=torch.int32, device=input_ids.device) - length = cdiv(self.flashinfer_extra_state.max_seq_length, self.page_size) + length = triton.cdiv(self.flashinfer_extra_state.max_seq_length, self.page_size) if self.batch_size <= model.graph_max_batch_size: self.kv_indices = self.flashinfer_extra_state.kv_indices_buffer[self.microbatch_index][ : self.batch_size * length @@ -41,28 +38,19 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): ) self.kv_starts = self.b1_cu_kv_seq_len.int() - if "page_size_variable" in model.mode: - b_page_len = cdiv(self.b_seq_len, self.page_size) + b_page_len = triton.cdiv(self.b_seq_len, self.page_size) + if self.page_size > 1: self.kv_starts[1:] = b_page_len.cumsum(0) self.kv_last_page_len = self.b_seq_len - (b_page_len - 1) * self.page_size - repack_paged_kv_index_from_tokens( - self.req_manager.req_to_token_indexs, - self.b_req_idx, - b_page_len, - self.kv_starts[:-1], - cdiv(self.max_kv_seq_len, self.page_size), - self.page_size, - self.kv_indices, - ) - else: - repack_kv_index( - self.req_manager.req_to_token_indexs, - self.b_req_idx, - self.b_seq_len, - self.b_start_loc, - self.max_kv_seq_len, - self.kv_indices, - ) + repack_kv_index( + self.req_manager.req_to_token_indexs, + self.b_req_idx, + b_page_len, + self.kv_starts[:-1], + triton.cdiv(self.max_kv_seq_len, self.page_size), + self.page_size, + self.kv_indices, + ) if self.decode_wrapper is None: self.decode_wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper( self.flashinfer_extra_state.workspace_buffer, @@ -90,34 +78,25 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): q_starts = self.b1_cu_q_seq_len.int() kv_starts = self.b1_cu_kv_seq_len.int() kv_last_page_len = torch.full((self.batch_size,), 1, dtype=torch.int32, device=input_ids.device) - length = cdiv(self.flashinfer_extra_state.max_seq_length, self.page_size) + length = triton.cdiv(self.flashinfer_extra_state.max_seq_length, 
self.page_size) kv_indices = torch.empty( self.batch_size * length, dtype=torch.int32, device=input_ids.device, ) - if "page_size_variable" in model.mode: - b_page_len = cdiv(self.b_seq_len, self.page_size) + b_page_len = triton.cdiv(self.b_seq_len, self.page_size) + if self.page_size > 1: kv_starts[1:] = b_page_len.cumsum(0) kv_last_page_len = self.b_seq_len - (b_page_len - 1) * self.page_size - repack_paged_kv_index_from_tokens( - self.req_manager.req_to_token_indexs, - self.b_req_idx, - b_page_len, - kv_starts[:-1], - cdiv(self.max_kv_seq_len, self.page_size), - self.page_size, - kv_indices, - ) - else: - repack_kv_index( - self.req_manager.req_to_token_indexs, - self.b_req_idx, - self.b_seq_len, - kv_starts[:-1], - self.max_kv_seq_len, - kv_indices, - ) + repack_kv_index( + self.req_manager.req_to_token_indexs, + self.b_req_idx, + b_page_len, + kv_starts[:-1], + triton.cdiv(self.max_kv_seq_len, self.page_size), + self.page_size, + kv_indices, + ) self.prefill_wrapper = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper( self.flashinfer_extra_state.workspace_buffer, qo_indptr_buf=q_starts, diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py index 66a169996..6191ba745 100755 --- a/lightllm/models/llama/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/llama/layer_infer/transformer_layer_infer.py @@ -87,14 +87,6 @@ def _bind_attention(self): self._copy_kv_to_mem_cache = partial( LlamaTransformerLayerInfer._copy_kv_to_mem_cache_with_calibration, self ) - elif "page_size_variable" in self.mode: - self._context_attention_kernel = partial( - LlamaTransformerLayerInfer._paged_context_attention_flashattention, self - ) - self._token_attention_kernel = partial( - LlamaTransformerLayerInfer._paged_token_decode_attention_flashattention, self - ) - self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self) elif not self.mode: self._context_attention_kernel = partial( LlamaTransformerLayerInfer._context_attention_flashattention, self @@ -107,16 +99,9 @@ def _bind_attention(self): raise Exception(f"Unsupported mode for fa3 backend: {self.mode}") return elif get_env_start_args().enable_flashinfer_prefill: - if "page_size_variable" in self.mode: - self._context_attention_kernel = partial( - LlamaTransformerLayerInfer._paged_context_attention_flashinfer_kernel, self - ) - elif not self.mode: - self._context_attention_kernel = partial( - LlamaTransformerLayerInfer._context_attention_flashinfer_kernel, self - ) - else: - raise Exception(f"Unsupported mode for flashinfer backend: {self.mode}") + self._context_attention_kernel = partial( + LlamaTransformerLayerInfer._context_attention_flashinfer_kernel, self + ) else: self._context_attention_kernel = partial(LlamaTransformerLayerInfer._context_attention_kernel, self) if "ppl_int8kv" in self.mode: @@ -181,12 +166,6 @@ def _bind_attention(self): self._copy_kv_to_mem_cache = partial( LlamaTransformerLayerInfer._copy_kv_to_mem_cache_with_calibration, self ) - elif "page_size_variable" in self.mode: - assert get_env_start_args().enable_flashinfer_prefill and get_env_start_args().enable_flashinfer_decode - self._token_attention_kernel = partial( - LlamaTransformerLayerInfer._paged_token_decode_attention_flashinfer, self - ) - self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self) elif not self.mode: if get_env_start_args().enable_flashinfer_decode: self._token_attention_kernel = 
partial( @@ -276,19 +255,6 @@ def _context_attention_flashinfer_kernel_fp8( def _context_attention_flashinfer_kernel( self, q, kv, infer_state: LlamaFlashInferStateInfo, layer_weight, out=None - ) -> torch.Tensor: - o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out - kv = infer_state.mem_manager.kv_buffer[self.layer_num_] - kv = kv.unsqueeze(1) - infer_state.prefill_wrapper.run( - q.view(q.shape[0], -1, self.head_dim_), - (kv[:, :, : self.tp_k_head_num_, :], kv[:, :, self.tp_k_head_num_ :, :]), - out=o_tensor.view(q.shape[0], -1, self.head_dim_), - ) - return o_tensor - - def _paged_context_attention_flashinfer_kernel( - self, q, kv, infer_state: LlamaFlashInferStateInfo, layer_weight, out=None ) -> torch.Tensor: o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out kv = infer_state.mem_manager.kv_buffer[self.layer_num_].view( @@ -352,9 +318,7 @@ def _context_attention_kernel_ppl_int8kv( ) return o_tensor - def _paged_context_attention_flashattention( - self, q, kv, infer_state: FlashAttentionStateInfo, layer_weight, out=None - ): + def _context_attention_flashattention(self, q, kv, infer_state: FlashAttentionStateInfo, layer_weight, out=None): cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :].reshape( -1, infer_state.page_size, self.tp_k_head_num_, self.head_dim_ ) @@ -384,36 +348,6 @@ def _paged_context_attention_flashattention( ) return o - def _context_attention_flashattention(self, q, kv, infer_state: FlashAttentionStateInfo, layer_weight, out=None): - cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :].reshape( - -1, 1, self.tp_k_head_num_, self.head_dim_ - ) - cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ].reshape(-1, 1, self.tp_v_head_num_, self.head_dim_) - q = q.reshape(-1, self.tp_q_head_num_, self.head_dim_) - k_descale, v_descale = None, None # disable quantization - Lq = q.shape[-1] - sm_scale = 1.0 / (Lq ** 0.5) - o = flash_attn_with_kvcache( - q=q, - k_cache=cache_k, - v_cache=cache_v, - page_table=infer_state.page_table, - cache_seqlens=infer_state.b_seq_len, - cu_seqlens_q=infer_state.cu_seqlens_q, - cu_seqlens_k_new=infer_state.cu_seqlens_k, - max_seqlen_q=infer_state.q_max_seq_len, - softmax_scale=sm_scale, - causal=True, - window_size=(-1, -1), - softcap=0.0, - k_descale=k_descale, - v_descale=v_descale, - return_softmax_lse=False, - ) - return o - def _context_attention_flashattention_fp8( self, q, kv, infer_state: FlashAttentionStateInfo, layer_weight, out=None ): @@ -604,21 +538,6 @@ def _token_decode_attention_flashinfer(self, q, infer_state: LlamaFlashInferStat batch_size = infer_state.batch_size calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_) - o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out - kv = infer_state.mem_manager.kv_buffer[self.layer_num_].unsqueeze(1) - infer_state.decode_wrapper.run( - q.view(calcu_shape1), - (kv[:, :, : self.tp_k_head_num_, :], kv[:, :, self.tp_k_head_num_ :, :]), - out=o_tensor.view(calcu_shape1), - ) - return o_tensor - - def _paged_token_decode_attention_flashinfer( - self, q, infer_state: LlamaFlashInferStateInfo, layer_weight, out=None - ): - batch_size = infer_state.batch_size - calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_) - o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out kv = infer_state.mem_manager.kv_buffer[self.layer_num_].view( -1, 
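The paged FlashInfer kernels above reinterpret the token-major kv_buffer as (num_pages, page_size, heads, head_dim) with a .view, which is only valid because the allocator hands out whole, contiguous pages. A small self-contained check of the identity this relies on (shapes are made up):

import torch

num_pages, page_size, heads, head_dim = 8, 4, 2, 16
flat_kv = torch.randn(num_pages * page_size, heads, head_dim)    # token-major buffer
paged_kv = flat_kv.view(num_pages, page_size, heads, head_dim)   # page-major view, no copy

t = 13  # any token index: it lives in page t // page_size, slot t % page_size
assert torch.equal(paged_kv[t // page_size, t % page_size], flat_kv[t])
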
infer_state.page_size, 2 * self.tp_k_head_num_, self.head_dim_ @@ -908,9 +827,7 @@ def _token_decode_attention_gqa_flashdecoding_vsm( alloc_tensor_func=self.alloc_tensor, ) - def _paged_token_decode_attention_flashattention( - self, q, infer_state: FlashAttentionStateInfo, layer_weight, out=None - ): + def _token_decode_attention_flashattention(self, q, infer_state: FlashAttentionStateInfo, layer_weight, out=None): cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :].reshape( -1, infer_state.page_size, self.tp_k_head_num_, self.head_dim_ ) @@ -921,36 +838,6 @@ def _paged_token_decode_attention_flashattention( k_descale, v_descale = None, None # disable quantization Lq = q.shape[-1] sm_scale = 1.0 / (Lq ** 0.5) - o = flash_attn_with_kvcache( - q=q, - k_cache=cache_k, - v_cache=cache_v, - page_table=infer_state.page_table, - cache_seqlens=infer_state.b_seq_len, - cu_seqlens_q=infer_state.cu_seqlens_q, - cu_seqlens_k_new=infer_state.cu_seqlens_k, - max_seqlen_q=1, - softmax_scale=sm_scale, - causal=False, - window_size=(-1, -1), - softcap=0.0, - k_descale=k_descale, - v_descale=v_descale, - return_softmax_lse=False, - ) - return o - - def _token_decode_attention_flashattention(self, q, infer_state: FlashAttentionStateInfo, layer_weight, out=None): - cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :].reshape( - -1, 1, self.tp_k_head_num_, self.head_dim_ - ) - cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ].reshape(-1, 1, self.tp_v_head_num_, self.head_dim_) - q = q.reshape(-1, self.tp_q_head_num_, self.head_dim_) - k_descale, v_descale = None, None # disable quantization - Lq = q.shape[-1] - sm_scale = 1.0 / (Lq ** 0.5) o = flash_attn_with_kvcache( q=q, k_cache=cache_k, diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index b52d72d14..ae9f7541d 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -179,7 +179,7 @@ def make_argument_parser() -> argparse.ArgumentParser: nargs="+", help="""Model mode: [triton_int8kv | ppl_int8kv | ppl_fp16 | triton_flashdecoding | triton_gqa_attention | triton_gqa_flashdecoding | triton_fp8kv | offline_calibration_fp8kv - | export_fp8kv_calibration | page_size_variable + | export_fp8kv_calibration triton_flashdecoding mode is for long context, current support llama llama2 qwen; triton_gqa_attention and triton_gqa_flashdecoding is fast kernel for model which use GQA; triton_int8kv mode use int8 to store kv cache, can increase token capacity, use triton kernel; @@ -191,8 +191,6 @@ def make_argument_parser() -> argparse.ArgumentParser: Calibration need to disable cudagraph and use fa3 or flashinfer backend. 
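Both flashattention paths above hand the kernel a page-major cache together with page_table and cache_seqlens. As a conceptual reference only (the fused kernel never materializes this), request i's keys could be rebuilt by concatenating its pages and truncating to its true length:

import torch

def gather_request_k(cache_k, page_table, cache_seqlens, i, page_size):
    # cache_k: (num_pages, page_size, k_heads, head_dim); page_table: (batch, max_pages)
    seq_len = int(cache_seqlens[i])
    n_pages = (seq_len + page_size - 1) // page_size
    pages = cache_k[page_table[i, :n_pages].long()]    # (n_pages, page_size, H, D)
    return pages.reshape(n_pages * page_size, *cache_k.shape[2:])[:seq_len]
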
ppl_int8kv mode use int8 to store kv cache, and use ppl fast kernel; ppl_fp16 mode use ppl fast fp16 decode attention kernel; - page_size_variable allow to use page size > 1, use PAGE_SIZE env to set page size, - page_size_variable only support fa3 and flashinfer backend for now you need to read source code to make sure the supported detail mode for all models""", ) parser.add_argument( diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 6b89e8f1a..03c519d7b 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -126,13 +126,6 @@ def normal_or_p_d_start(args): "--enable_flashinfer_prefill and --enable_flashinfer_decode" ) assert args.disable_cudagraph is True, "export_fp8kv_calibration mode need disable cudagraph" - if "page_size_variable" in args.mode: - assert args.enable_fa3 is True or ( - args.enable_flashinfer_prefill is True and args.enable_flashinfer_decode is True - ), ( - "page_size_variable mode need enable fa3 or flashinfer, add --enable_fa3 or " - "--enable_flashinfer_prefill and --enable_flashinfer_decode" - ) # 部分模式还不能支持与高级动态调度算法协同,to do. if args.diverse_mode: diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 99830a8db..80ebb09d7 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -27,7 +27,7 @@ from lightllm.utils.dist_utils import get_dp_world_size, get_global_dp_rank, get_current_rank_in_dp from lightllm.utils.dist_utils import get_current_device_id, get_current_rank_in_node, get_node_world_size from lightllm.utils.dist_utils import get_dp_rank_in_node, create_new_group_for_current_node -from lightllm.utils.envs_utils import get_env_start_args +from lightllm.utils.envs_utils import get_env_start_args, get_page_size from lightllm.distributed import dist_group_manager from lightllm.server.router.shm_reqs_io_buffer import ShmReqsIOBuffer from lightllm.server.router.model_infer.mode_backend.overlap_events import OverlapEventManager, OverlapEventPack @@ -140,7 +140,7 @@ def init_model(self, kvargs): self.model, self.is_multimodal = get_model(model_cfg, model_kvargs) self.model: TpPartBaseModel = self.model # for easy typing set_random_seed(2147483647) - radix_cache_class = PagedRadixCache if "page_size_variable" in self.mode else RadixCache + radix_cache_class = PagedRadixCache if get_page_size() > 1 else RadixCache self.radix_cache = ( radix_cache_class( get_unique_server_name(), diff --git a/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py index 5d94996a2..dc57a9708 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py @@ -81,7 +81,7 @@ def padded_prepare_prefill_inputs( input_ids.shape[0] - padded_req_num, b_seq_len, b_ready_cache_len ) g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(token_num) - mem_indexes = g_infer_context.req_manager.alloc_token_indices( + mem_indexes = g_infer_context.req_manager.alloc_mem_indices( input_ids.shape[0] - padded_req_num, b_req_idx, b_seq_len, b_ready_cache_len ) @@ -170,7 +170,7 @@ def padded_prepare_decode_inputs( if g_infer_context.radix_cache is not None: token_num = g_infer_context.req_manager.calc_real_need_token_num(b_seq_len.shape[0] - 
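The prefill path above asks calc_real_need_token_num for a page-aligned figure before freeing radix-cache entries, because with page_size > 1 each request's new tokens occupy whole pages. The body of _get_need_paged_token_num is not part of this hunk; the sketch below assumes per-request ceil-to-page rounding:

import torch

def paged_prefill_token_need(b_seq_len, b_ready_cache_len, page_size):
    # new tokens per request, rounded up to whole pages, summed over the batch
    b_token_len = b_seq_len - b_ready_cache_len
    pages = (b_token_len + page_size - 1) // page_size
    return int(pages.sum()) * page_size

# e.g. new-token counts [5, 8] with page_size=4 reserve (2 + 2) * 4 = 16 KV slots
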
padded_req_num, b_seq_len) g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(token_num) - mem_indexes = g_infer_context.req_manager.alloc_token_indices( + mem_indexes = g_infer_context.req_manager.alloc_mem_indices( b_seq_len.shape[0] - padded_req_num, b_req_idx, b_seq_len ) g_infer_state_lock.release() diff --git a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py index 2ed7cf51d..a0d434855 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py @@ -59,7 +59,7 @@ def prepare_prefill_inputs( input_ids.shape[0], b_seq_len, b_ready_cache_len ) g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(token_num) - mem_indexes = g_infer_context.req_manager.alloc_token_indices( + mem_indexes = g_infer_context.req_manager.alloc_mem_indices( input_ids.shape[0], b_req_idx, b_seq_len, b_ready_cache_len ) g_infer_state_lock.release() @@ -118,7 +118,7 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In if g_infer_context.radix_cache is not None: token_num = g_infer_context.req_manager.calc_real_need_token_num(b_seq_len.shape[0], b_seq_len) g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(token_num) - mem_indexes = g_infer_context.req_manager.alloc_token_indices(b_seq_len.shape[0], b_req_idx, b_seq_len) + mem_indexes = g_infer_context.req_manager.alloc_mem_indices(b_seq_len.shape[0], b_req_idx, b_seq_len) g_infer_state_lock.release() model_input = ModelInput( diff --git a/lightllm/utils/envs_utils.py b/lightllm/utils/envs_utils.py index c3a9469f0..e9e2162e3 100644 --- a/lightllm/utils/envs_utils.py +++ b/lightllm/utils/envs_utils.py @@ -160,11 +160,7 @@ def set_triton_autotune_level(level: int): @lru_cache(maxsize=None) def get_page_size(): - try: - args = get_env_start_args() - return int(os.getenv("PAGE_SIZE", 64)) if "page_size_variable" in args.mode else 1 - except: - return 1 + return int(os.getenv("PAGE_SIZE", 64)) g_model_init_done = False diff --git a/test/benchmark/static_inference/model_infer.py b/test/benchmark/static_inference/model_infer.py index b8012ae26..d1cc9ff79 100644 --- a/test/benchmark/static_inference/model_infer.py +++ b/test/benchmark/static_inference/model_infer.py @@ -258,9 +258,7 @@ def run_forward_once( b_seq_len[i] = input_len total_token_num = batch_size * input_len - mem_indexes = model_part.req_manager.alloc_token_indices( - test_data.shape[0], b_req_idx, b_seq_len, b_ready_cache_len - ) + mem_indexes = model_part.req_manager.alloc_mem_indices(test_data.shape[0], b_req_idx, b_seq_len, b_ready_cache_len) b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cpu") rank_id = model_kvargs["rank_id"] @@ -323,7 +321,7 @@ def run_forward_once( step_start = time.time() total_token_num += batch_size b_seq_len += 1 - mem_indexes = model_part.req_manager.alloc_token_indices(predict_ids.shape[0], b_req_idx, b_seq_len) + mem_indexes = model_part.req_manager.alloc_mem_indices(predict_ids.shape[0], b_req_idx, b_seq_len) max_len_in_batch = input_len + i + 1 logits = decode_fn( model_part, diff --git a/test/benchmark/static_inference/model_infer_mtp.py b/test/benchmark/static_inference/model_infer_mtp.py index 5efa57fa4..d94f0de4a 100644 --- a/test/benchmark/static_inference/model_infer_mtp.py +++ b/test/benchmark/static_inference/model_infer_mtp.py @@ -124,7 +124,7 @@ def run_forward_once(args, input_len, 
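Because get_page_size is wrapped in lru_cache, whatever PAGE_SIZE holds at the first call is what the whole process keeps using. A tiny illustration with a stand-in function (the real helper lives in lightllm.utils.envs_utils):

import os
from functools import lru_cache

@lru_cache(maxsize=None)
def get_page_size_demo():
    return int(os.getenv("PAGE_SIZE", 64))

os.environ["PAGE_SIZE"] = "16"
assert get_page_size_demo() == 16      # first call populates the cache
os.environ["PAGE_SIZE"] = "64"
assert get_page_size_demo() == 16      # later changes to the env var are ignored
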
output_len, batch_size, main_model, draft_ b_seq_len[i] = input_len total_token_num = input_len * batch_size - mem_indexes = main_model.req_manager.alloc_token_indices( + mem_indexes = main_model.req_manager.alloc_mem_indices( test_data.shape[0], b_req_idx, b_seq_len, b_ready_cache_len ).cuda() # Main model Prefill @@ -193,7 +193,7 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_ nopad_b_seq_idx = torch.tensor(nopad_b_seq_idx, dtype=torch.int32, device="cuda") nopad_b_seq_len = torch.tensor(nopad_b_seq_len, dtype=torch.int32, device="cuda") - mem_indexes = main_model.req_manager.alloc_token_indices( + mem_indexes = main_model.req_manager.alloc_mem_indices( batch_size * (len(draft_models) + 1), nopad_b_seq_idx, nopad_b_seq_len ).cuda() From 6801c78554809de920777713fd15eee8636b291b Mon Sep 17 00:00:00 2001 From: niushengxiao Date: Thu, 4 Sep 2025 19:33:09 +0800 Subject: [PATCH 9/9] feat: add b_last_mem_indx in the InferReq --- lightllm/common/basemodel/basemodel.py | 6 ++-- lightllm/common/basemodel/cuda_graph.py | 10 ++++-- lightllm/common/req_manager.py | 34 ++++++++++++------- .../server/router/model_infer/infer_batch.py | 1 + .../generic_padded_pre_process.py | 14 ++++++-- .../mode_backend/generic_pre_process.py | 16 +++++++-- lightllm/utils/envs_utils.py | 2 +- .../benchmark/static_inference/model_infer.py | 8 +++-- .../static_inference/model_infer_mtp.py | 7 ++-- 9 files changed, 67 insertions(+), 31 deletions(-) diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 5a34b4993..088df2291 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -687,9 +687,7 @@ def _check_max_len_infer(self): b_seq_len = torch.ones(1, dtype=torch.int32, device="cuda") b_seq_len[:] = self.batch_max_tokens b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda") - mem_indexes = self.req_manager.alloc_mem_indices( - len(dummy_input_ids), b_req_idx, b_seq_len, b_ready_cache_len - ).cuda() + mem_indexes = self.req_manager.alloc_mem_indices(len(dummy_input_ids), b_seq_len, b_ready_cache_len).cuda() total_token_num = self.batch_max_tokens b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cuda") model_input = ModelInput( @@ -765,7 +763,7 @@ def _autotune_warmup(self): total_token_num = input_len b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cuda") mem_indexes = self.req_manager.alloc_mem_indices( - len(dummy_input_ids), b_req_idx, b_seq_len, b_ready_cache_len + len(dummy_input_ids), b_seq_len, b_ready_cache_len ).cuda() model_input = ModelInput( batch_size=1, diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/cuda_graph.py index 6a2eac4b0..3027f4ca3 100644 --- a/lightllm/common/basemodel/cuda_graph.py +++ b/lightllm/common/basemodel/cuda_graph.py @@ -201,8 +201,11 @@ def warmup(self, model): ) b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda") b_seq_len.fill_(seq_len) + b_last_mem_index = torch.zeros_like(b_seq_len) b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda") - mem_indexes = model.req_manager.alloc_mem_indices(len(input_ids), b_req_idx, b_seq_len).cuda() + mem_indexes = model.req_manager.alloc_mem_indices( + len(input_ids), b_seq_len, b_last_mem_index=b_last_mem_index + ).cuda() model_input = ModelInput( batch_size=batch_size, @@ -257,8 +260,11 @@ def warmup_overlap(self, model): ) b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda") b_seq_len.fill_(seq_len) + 
b_last_mem_index = torch.zeros_like(b_seq_len) b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda") - mem_indexes = model.req_manager.alloc_mem_indices(len(input_ids), b_req_idx, b_seq_len).cuda() + mem_indexes = model.req_manager.alloc_mem_indices( + len(input_ids), b_seq_len, b_last_mem_index=b_last_mem_index + ).cuda() micro_batch = ModelInput( is_prefill=False, diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py index d65968025..3f85d7f57 100644 --- a/lightllm/common/req_manager.py +++ b/lightllm/common/req_manager.py @@ -71,10 +71,21 @@ def __init__(self, max_request_num, max_sequence_length, mem_manager: MemoryMana def calc_real_need_token_num(self, need_token_num, b_seq_len, b_ready_cache_len=None): return max(need_token_num, self._get_need_paged_token_num(b_seq_len, b_ready_cache_len)) - def alloc_mem_indices(self, need_size, b_req_idx=None, b_seq_len=None, b_ready_cache_len=None) -> torch.Tensor: + def calc_last_mem_index_in_prefill(self, mem_indices, b_seq_len, b_ready_cache_len=None): + b_token_len = b_seq_len + if b_ready_cache_len is not None: + b_token_len = b_seq_len - b_ready_cache_len + b_token_len_cumsum = torch.cumsum(b_token_len, dim=0) + b_last_mem_index = mem_indices[b_token_len_cumsum - 1] + return b_last_mem_index + + # b_ready_cache_len为None时才需要b_last_mem_index + def alloc_mem_indices( + self, need_size, b_seq_len=None, b_ready_cache_len=None, b_last_mem_index=None + ) -> torch.Tensor: page_size = get_page_size() - if page_size > 1 and b_req_idx is not None and b_seq_len is not None: - return self._alloc_paged_mem_indices(b_req_idx, page_size, b_seq_len, b_ready_cache_len) + if page_size > 1 and b_seq_len is not None: + return self._alloc_paged_mem_indices(page_size, b_seq_len, b_ready_cache_len, b_last_mem_index) else: return self.mem_manager.alloc(need_size) @@ -114,12 +125,11 @@ def _expand_by_page_size(self, b_token_len, page_size): p_token_len[last_page_positions] = remainders return need_pages_num, p_token_len - def _alloc_paged_mem_indices(self, b_req_idx, page_size, b_seq_len, b_ready_cache_len): + def _alloc_paged_mem_indices(self, page_size, b_seq_len, b_ready_cache_len, b_last_mem_index): + b_seq_len = b_seq_len.cpu() if b_ready_cache_len is not None: # prefill - b_seq_len = b_seq_len.cpu() b_ready_cache_len = b_ready_cache_len.cpu() - b_token_len = b_seq_len - b_ready_cache_len total_pages_needed, p_token_len = self._expand_by_page_size(b_token_len, page_size) paged_token_idxs = self.mem_manager.alloc(total_pages_needed * page_size) @@ -128,19 +138,17 @@ def _alloc_paged_mem_indices(self, b_req_idx, page_size, b_seq_len, b_ready_cach return pages[mask] else: # decode - b_seq_len = b_seq_len.cuda() - b_req_idx = b_req_idx.cuda() + assert b_last_mem_index is not None + b_last_mem_index = b_last_mem_index.cpu() need_new_page_mask = (b_seq_len - 1) % page_size == 0 - new_pages_num = need_new_page_mask.sum().cpu() + new_pages_num = need_new_page_mask.sum() token_idxs = torch.zeros_like(b_seq_len, device=b_seq_len.device) if new_pages_num > 0: - new_pages_tokens = self.mem_manager.alloc(new_pages_num * page_size).cuda() + new_pages_tokens = self.mem_manager.alloc(new_pages_num * page_size) token_idxs[need_new_page_mask] = new_pages_tokens[::page_size] - mask = ~need_new_page_mask if mask.any(): - seq_lens = b_seq_len[mask] - token_idxs[mask] = self.req_to_token_indexs[b_req_idx[mask], seq_lens - 2] + 1 + token_idxs[mask] = b_last_mem_index[mask] + 1 return token_idxs def _get_need_paged_token_num(self, b_seq_len, 
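In the decode branch above, b_last_mem_index replaces the old req_to_token_indexs lookup: a request only gets a fresh page when its previous tokens end exactly on a page boundary, otherwise the next KV slot is simply the last index plus one (per the comment, b_last_mem_index is only needed when b_ready_cache_len is None, i.e. for decode). A worked example with hypothetical values; the page-start index 120 stands in for whatever mem_manager.alloc would return:

import torch

page_size = 4
b_seq_len = torch.tensor([9, 7])            # request 0's previous 8 tokens fill two whole pages
b_last_mem_index = torch.tensor([34, 101])  # KV slot holding each request's previous token

need_new_page = (b_seq_len - 1) % page_size == 0     # tensor([True, False])
token_idxs = torch.zeros_like(b_seq_len)
token_idxs[need_new_page] = 120                      # pretend a new page was allocated at slot 120
token_idxs[~need_new_page] = b_last_mem_index[~need_new_page] + 1
print(token_idxs.tolist())                           # [120, 102]
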
b_ready_cache_len=None): diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 01ae6c9c5..a5eb4c8b9 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -288,6 +288,7 @@ def __init__( self.shm_index = shm_index self.multimodal_params = multimodal_params self.vocab_size = vocab_size + self.last_kv_mem_index = -1 # 请求需要被暂停 self.wait_pause = False diff --git a/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py index dc57a9708..38079cb3f 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py @@ -82,8 +82,13 @@ def padded_prepare_prefill_inputs( ) g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(token_num) mem_indexes = g_infer_context.req_manager.alloc_mem_indices( - input_ids.shape[0] - padded_req_num, b_req_idx, b_seq_len, b_ready_cache_len + input_ids.shape[0] - padded_req_num, b_seq_len, b_ready_cache_len ) + b_last_mem_index = g_infer_context.req_manager.calc_last_mem_index_in_prefill( + mem_indexes, b_seq_len, b_ready_cache_len + ) + for i, req in enumerate(req_objs): + req.last_kv_mem_index = b_last_mem_index[i].item() g_infer_state_lock.release() @@ -123,6 +128,7 @@ def padded_prepare_decode_inputs( b_req_idx = [] b_mtp_index = [] b_seq_len = [] + b_last_mem_index = [] for req in req_objs: run_reqs.append(req) b_req_idx.append(req.req_idx) @@ -132,6 +138,7 @@ def padded_prepare_decode_inputs( total_token_num += seq_len max_len_in_batch = max(max_len_in_batch, seq_len) b_mtp_index.append(0) + b_last_mem_index.append(req.last_kv_mem_index) # process the draft tokens. 
for step in range(req.mtp_step): run_reqs.append(req) @@ -164,6 +171,7 @@ def padded_prepare_decode_inputs( b_req_idx = torch.tensor(b_req_idx, dtype=torch.int32, device="cpu") b_seq_len = torch.tensor(b_seq_len, dtype=torch.int32, device="cpu") b_mtp_index = torch.tensor(b_mtp_index, dtype=torch.int32, device="cpu") + b_last_mem_index = torch.tensor(b_last_mem_index, dtype=torch.int32, device="cpu") # dynamic prompt cache 准备 token g_infer_state_lock.acquire() @@ -171,8 +179,10 @@ def padded_prepare_decode_inputs( token_num = g_infer_context.req_manager.calc_real_need_token_num(b_seq_len.shape[0] - padded_req_num, b_seq_len) g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(token_num) mem_indexes = g_infer_context.req_manager.alloc_mem_indices( - b_seq_len.shape[0] - padded_req_num, b_req_idx, b_seq_len + b_seq_len.shape[0] - padded_req_num, b_seq_len, b_last_mem_index=b_last_mem_index ) + for i, req in enumerate(req_objs): + req.last_kv_mem_index = mem_indexes[i] g_infer_state_lock.release() if padded_req_num > 0: diff --git a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py index a0d434855..531f54869 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py @@ -59,9 +59,12 @@ def prepare_prefill_inputs( input_ids.shape[0], b_seq_len, b_ready_cache_len ) g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(token_num) - mem_indexes = g_infer_context.req_manager.alloc_mem_indices( - input_ids.shape[0], b_req_idx, b_seq_len, b_ready_cache_len + mem_indexes = g_infer_context.req_manager.alloc_mem_indices(input_ids.shape[0], b_seq_len, b_ready_cache_len) + b_last_mem_index = g_infer_context.req_manager.calc_last_mem_index_in_prefill( + mem_indexes, b_seq_len, b_ready_cache_len ) + for i, req in enumerate(req_objs): + req.last_kv_mem_index = b_last_mem_index[i].item() g_infer_state_lock.release() model_input = ModelInput( @@ -90,6 +93,7 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In b_req_idx = [] b_mtp_index = [] b_seq_len = [] + b_last_mem_index = [] for req in req_objs: run_reqs.append(req) b_req_idx.append(req.req_idx) @@ -99,6 +103,7 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In total_token_num += seq_len max_len_in_batch = max(max_len_in_batch, seq_len) b_mtp_index.append(0) + b_last_mem_index.append(req.last_kv_mem_index) # process the draft tokens. 
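Both pre-process paths above keep the same per-request bookkeeping: prefill seeds last_kv_mem_index from calc_last_mem_index_in_prefill, and every decode step feeds the remembered slots into alloc_mem_indices and then overwrites them with the freshly allocated ones. A simplified sketch of that lifecycle (the class and the tensors are stand-ins, not the real InferReq or allocator output):

import torch

class _Req:                                  # simplified stand-in for InferReq
    def __init__(self):
        self.last_kv_mem_index = -1          # sentinel until prefill assigns a real slot

reqs = [_Req(), _Req()]

# after prefill: remember the slot holding each request's last prompt token
b_last_prefill = torch.tensor([34, 101])     # would come from calc_last_mem_index_in_prefill(...)
for i, req in enumerate(reqs):
    req.last_kv_mem_index = int(b_last_prefill[i])

# each decode step: pass the remembered slots to the allocator, then refresh them
b_last = torch.tensor([r.last_kv_mem_index for r in reqs], dtype=torch.int32)
new_mem_indexes = torch.tensor([35, 102])    # would come from alloc_mem_indices(..., b_last_mem_index=b_last)
for i, req in enumerate(reqs):
    req.last_kv_mem_index = int(new_mem_indexes[i])
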
for step in range(req.mtp_step): run_reqs.append(req) @@ -112,13 +117,18 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In b_req_idx = torch.tensor(b_req_idx, dtype=torch.int32, device="cpu") b_seq_len = torch.tensor(b_seq_len, dtype=torch.int32, device="cpu") b_mtp_index = torch.tensor(b_mtp_index, dtype=torch.int32, device="cpu") + b_last_mem_index = torch.tensor(b_last_mem_index, dtype=torch.int32, device="cpu") # dynamic prompt cache 准备 token g_infer_state_lock.acquire() if g_infer_context.radix_cache is not None: token_num = g_infer_context.req_manager.calc_real_need_token_num(b_seq_len.shape[0], b_seq_len) g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(token_num) - mem_indexes = g_infer_context.req_manager.alloc_mem_indices(b_seq_len.shape[0], b_req_idx, b_seq_len) + mem_indexes = g_infer_context.req_manager.alloc_mem_indices( + b_seq_len.shape[0], b_seq_len, b_last_mem_index=b_last_mem_index + ) + for i, req in enumerate(req_objs): + req.last_kv_mem_index = mem_indexes[i] g_infer_state_lock.release() model_input = ModelInput( diff --git a/lightllm/utils/envs_utils.py b/lightllm/utils/envs_utils.py index e9e2162e3..3effd1a40 100644 --- a/lightllm/utils/envs_utils.py +++ b/lightllm/utils/envs_utils.py @@ -160,7 +160,7 @@ def set_triton_autotune_level(level: int): @lru_cache(maxsize=None) def get_page_size(): - return int(os.getenv("PAGE_SIZE", 64)) + return int(os.getenv("PAGE_SIZE", 1)) g_model_init_done = False diff --git a/test/benchmark/static_inference/model_infer.py b/test/benchmark/static_inference/model_infer.py index d1cc9ff79..e0f262f93 100644 --- a/test/benchmark/static_inference/model_infer.py +++ b/test/benchmark/static_inference/model_infer.py @@ -258,7 +258,8 @@ def run_forward_once( b_seq_len[i] = input_len total_token_num = batch_size * input_len - mem_indexes = model_part.req_manager.alloc_mem_indices(test_data.shape[0], b_req_idx, b_seq_len, b_ready_cache_len) + mem_indexes = model_part.req_manager.alloc_mem_indices(test_data.shape[0], b_seq_len, b_ready_cache_len) + b_last_mem_index = model_part.req_manager.calc_last_mem_index_in_prefill(mem_indexes, b_seq_len, b_ready_cache_len) b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cpu") rank_id = model_kvargs["rank_id"] @@ -321,7 +322,10 @@ def run_forward_once( step_start = time.time() total_token_num += batch_size b_seq_len += 1 - mem_indexes = model_part.req_manager.alloc_mem_indices(predict_ids.shape[0], b_req_idx, b_seq_len) + mem_indexes = model_part.req_manager.alloc_mem_indices( + predict_ids.shape[0], b_seq_len, b_last_mem_index=b_last_mem_index + ) + b_last_mem_index = mem_indexes max_len_in_batch = input_len + i + 1 logits = decode_fn( model_part, diff --git a/test/benchmark/static_inference/model_infer_mtp.py b/test/benchmark/static_inference/model_infer_mtp.py index d94f0de4a..cdb1f592e 100644 --- a/test/benchmark/static_inference/model_infer_mtp.py +++ b/test/benchmark/static_inference/model_infer_mtp.py @@ -124,9 +124,8 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_ b_seq_len[i] = input_len total_token_num = input_len * batch_size - mem_indexes = main_model.req_manager.alloc_mem_indices( - test_data.shape[0], b_req_idx, b_seq_len, b_ready_cache_len - ).cuda() + mem_indexes = main_model.req_manager.alloc_mem_indices(test_data.shape[0], b_seq_len, b_ready_cache_len).cuda() + b_last_mem_index = main_model.req_manager.calc_last_mem_index_in_prefill(mem_indexes, b_seq_len, b_ready_cache_len) # Main model 
Prefill model_input = ModelInput( batch_size=batch_size, @@ -194,7 +193,7 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_ nopad_b_seq_idx = torch.tensor(nopad_b_seq_idx, dtype=torch.int32, device="cuda") nopad_b_seq_len = torch.tensor(nopad_b_seq_len, dtype=torch.int32, device="cuda") mem_indexes = main_model.req_manager.alloc_mem_indices( - batch_size * (len(draft_models) + 1), nopad_b_seq_idx, nopad_b_seq_len + batch_size * (len(draft_models) + 1), nopad_b_seq_len, b_last_mem_index=b_last_mem_index ).cuda() model_input = ModelInput(