From 153a40929c5271597462e219708f4ba37ab59282 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 9 Dec 2025 05:41:50 +0000 Subject: [PATCH 01/34] fix --- .../kv_cache_mem_manager/mem_manager.py | 35 ++++ .../common/kv_trans_kernel/kv_trans_v2.py | 99 +++++++++++ lightllm/server/api_cli.py | 6 + lightllm/server/core/objs/req.py | 28 ++++ lightllm/server/core/objs/start_args_type.py | 1 + lightllm/server/router/manager.py | 1 + .../server/router/model_infer/infer_batch.py | 1 + .../mode_backend/dp_backend/impl.py | 154 ++++++++++++++++- .../mode_backend/dp_backend/p2p_fix.py | 155 ++++++++++++++++++ .../server/router/model_infer/model_rpc.py | 9 +- 10 files changed, 485 insertions(+), 4 deletions(-) create mode 100644 lightllm/server/router/model_infer/mode_backend/dp_backend/p2p_fix.py diff --git a/lightllm/common/kv_cache_mem_manager/mem_manager.py b/lightllm/common/kv_cache_mem_manager/mem_manager.py index daa061b7b..9f94a2edd 100755 --- a/lightllm/common/kv_cache_mem_manager/mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/mem_manager.py @@ -3,6 +3,7 @@ import torch import torch.distributed as dist from typing import List, Union +from lightllm.common.kv_trans_kernel.kv_trans_v2 import kv_trans_for_dp from lightllm.server.pd_io_struct import KVMoveTask from lightllm.utils.log_utils import init_logger from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt @@ -401,6 +402,40 @@ def get_index_kv_buffer(self, index): def load_index_kv_buffer(self, index, load_tensor_dict): self.kv_buffer[:, index].copy_(load_tensor_dict["kv_buffer"]) + def copy_kv_from_other_dp_ranks( + self, + mem_managers: List["MemoryManager"], + move_token_indexes: torch.Tensor, + token_dp_indexes: torch.Tensor, + mem_indexes: torch.Tensor, + dp_size_in_node: int, + rank_in_dp: int, + ): + if not hasattr(self, "mem_ptrs_dict"): + self.mem_ptrs_dict = {} + for layer_index in range(self.layer_num): + mems_ptr = [] + for i in range(0, len(mem_managers)): + mems_ptr.append(mem_managers[i].kv_buffer[layer_index, :, :, :].data_ptr()) + mems_ptr = torch.tensor(mems_ptr, dtype=torch.uint64, device="cuda") + self.mem_ptrs_dict[layer_index] = mems_ptr + + input_mems = [] + for i in range(len(self.mem_managers)): + input_mems.append(self.mem_managers[i].kv_buffer.data_ptr()) + input_mems = torch.tensor(input_mems, dtype=torch.uint64, device="cuda") + + for layer_index in range(self.layer_num): + kv_trans_for_dp( + input_mems=input_mems[layer_index], + input_idx=move_token_indexes, + input_dp_idx=token_dp_indexes, + output=self.kv_buffer[layer_index], + output_idx=mem_indexes, + dp_size_in_node=dp_size_in_node, + rank_in_dp=rank_in_dp, + ) + class ReadOnlyStaticsMemoryManager: """ diff --git a/lightllm/common/kv_trans_kernel/kv_trans_v2.py b/lightllm/common/kv_trans_kernel/kv_trans_v2.py index 772de5f6c..912587e1e 100644 --- a/lightllm/common/kv_trans_kernel/kv_trans_v2.py +++ b/lightllm/common/kv_trans_kernel/kv_trans_v2.py @@ -191,3 +191,102 @@ def kv_trans_v2_for_d_node( num_warps=1, ) return + + +@triton.jit +def _kv_trans_for_dp_kernel( + input_mems_ptr, + input_stride_0, + input_stride_1, + input_stride_2, + input_token_idx_ptr, + input_token_dp_index_ptr, + output_ptr, + output_stride_0, + output_stride_1, + output_stride_2, + output_token_idx_ptr, + token_num: int, + head_num: int, + head_dim: int, + grid_count: int, + BLOCK_SIZE: tl.constexpr, + NUM_STAGES: tl.constexpr, + CARD_NUM_PER_D: tl.constexpr, + RANK_IN_DP: tl.constexpr, +): + input_stride_0 = tl.cast(input_stride_0, dtype=tl.int64) + 
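    # The int64 casts of the strides here keep `stride * token_index` from overflowing
    # 32-bit arithmetic when the per-layer KV buffers hold a large number of token slots.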
input_stride_1 = tl.cast(input_stride_1, dtype=tl.int64) + output_stride_0 = tl.cast(output_stride_0, dtype=tl.int64) + output_stride_1 = tl.cast(output_stride_1, dtype=tl.int64) + + head_num_dim = head_num * head_dim + tid = tl.program_id(0) + + offs = tl.arange(0, BLOCK_SIZE) + while tid < token_num: + dp_index = tl.load(input_token_dp_index_ptr + tid) + mem_index = RANK_IN_DP + dp_index * CARD_NUM_PER_D + input_token_idx = tl.load(input_token_idx_ptr + tid) + output_token_idx = tl.load(output_token_idx_ptr + tid) + for block_idx in tl.range(0, tl.cdiv(head_num_dim, BLOCK_SIZE), 1, num_stages=NUM_STAGES): + cur_offs = block_idx * BLOCK_SIZE + offs + input_ptr = tl.load(input_mems_ptr + mem_index).to(tl.pointer_type(output_ptr.dtype.element_ty)) + in_datas = tl.load(input_ptr + input_stride_0 * input_token_idx + cur_offs, mask=cur_offs < head_num_dim) + tl.store(output_ptr + output_stride_0 * output_token_idx + cur_offs, in_datas, mask=cur_offs < head_num_dim) + + tid += grid_count + + return + + +def kv_trans_for_dp( + input_mems: torch.Tensor, + input_idx: torch.Tensor, + input_dp_idx: torch.Tensor, + output: torch.Tensor, + output_idx: torch.Tensor, + dp_size_in_node: int, + rank_in_dp: int, +): + """ + input_mems 是一个 torch.uint64 的tensor, 其内部存储了当前使用的对应的mem_manager对象中kv cache的首指针。 + """ + assert input_mems.is_contiguous() + assert output.is_contiguous() + assert len(input_mems.shape) == 1 + assert len(output.shape) == 3 + assert len(input_idx) == len(output_idx) + assert len(output_idx) == len(input_dp_idx) + assert len(input_mems) % dp_size_in_node == 0 + + card_num_per_d = len(input_mems) // dp_size_in_node + + _, head_num, head_dim = output.shape + token_num = len(output_idx) + # 用较少的资源来做数据传输,防止占用过多的 sm 计算单元 + grid_count = 20 + BLOCK_SIZE = 256 + NUM_STAGES = 3 + grid = (grid_count,) + + _kv_trans_for_dp_kernel[grid]( + input_mems, + *output.stride(), + input_idx, + input_dp_idx, + output, + *output.stride(), + output_idx, + token_num=token_num, + head_num=head_num, + head_dim=head_dim, + grid_count=grid_count, + BLOCK_SIZE=BLOCK_SIZE, + NUM_STAGES=NUM_STAGES, + CARD_NUM_PER_D=card_num_per_d, + RANK_IN_DP=rank_in_dp, + num_warps=1, + ) + + return diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index c4f1961d3..b5dc60749 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -566,4 +566,10 @@ def make_argument_parser() -> argparse.ArgumentParser: default=None, help="""Directory used to persist disk cache data. Defaults to a temp directory when not set.""", ) + parser.add_argument( + "--disable_dp_prompt_cache_fetch", + action="store_true", + default=False, + help="""Disable prefix prompt cache fetch for data parallel inference. 
Enabled by default""", + ) return parser diff --git a/lightllm/server/core/objs/req.py b/lightllm/server/core/objs/req.py index 0d2e7ae38..61af36e5c 100644 --- a/lightllm/server/core/objs/req.py +++ b/lightllm/server/core/objs/req.py @@ -122,6 +122,10 @@ class Req(ctypes.Structure): ("cpu_cache_match_page_indexes", CpuCachePageList), # 分块hash的块大小 ("cpu_cache_token_page_size", ctypes.c_int), + # 所有DP中的最大kv cache的长度 + ("dp_max_kv_len", ctypes.c_int), + # 拥有最大kv cache长度的dp_rank + ("dp_max_kv_rank", ctypes.c_int), ] def get_str(self): @@ -173,6 +177,7 @@ def init( self.alloc_shm_numpy_len = self.input_len + self.sample_params.max_new_tokens + 1024 # + 1024 for safe self.create_logprobs_shm_array() self.create_prompt_ids_shm_array() + self.create_kv_indexes_shm_array() self.chunked_prefill_size = chunked_prefill_size self.shm_prompt_ids.arr[0 : len(prompt_ids)] = prompt_ids self.mtp_accepted_token_num = 0 @@ -183,6 +188,9 @@ def init( self.post_init() self.cpu_cache_token_page_size = get_env_start_args().cpu_cache_token_page_size + # 初始化DP模式相关字段 + self.dp_max_kv_len = 0 + self.dp_max_kv_rank = -1 if get_env_start_args().enable_cpu_cache: self._fill_input_token_hash() return @@ -227,12 +235,32 @@ def link_logprobs_shm_array(self): self.shm_logprobs.link_shm() return + def create_kv_indexes_shm_array(self): + service_uni_name = get_unique_server_name() + name = f"{service_uni_name}_shm_kv_indexes_{self.index_in_shm_mem}" + self.shm_kv_indexes = ShmArray(name, (self.alloc_shm_numpy_len,), dtype=np.int64) + self.shm_kv_indexes.create_shm() + return + + def link_kv_indexes_shm_array(self): + service_uni_name = get_unique_server_name() + name = f"{service_uni_name}_shm_kv_indexes_{self.index_in_shm_mem}" + self.shm_kv_indexes = ShmArray(name, (self.alloc_shm_numpy_len,), dtype=np.int64) + self.shm_kv_indexes.link_shm() + return + def get_prompt_ids(self): return self.shm_prompt_ids.arr[: self.input_len].tolist() def get_prompt_ids_numpy(self): return self.shm_prompt_ids.arr[: self.input_len] + def get_kv_indexes(self): + return self.shm_kv_indexes.arr[: self.input_len].tolist() + + def get_kv_indexes_numpy(self): + return self.shm_kv_indexes.arr[: self.input_len] + def to_router_rpc_obj(self): if hasattr(self, "multimodal_params"): return ( diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 8a2727794..0e186c256 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -114,6 +114,7 @@ class StartArgs: enable_disk_cache: bool = field(default=False) disk_cache_storage_size: float = field(default=10) disk_cache_dir: Optional[str] = field(default=None) + disable_dp_prompt_cache_fetch: bool = field(default=False) # zmp ports router_port: int = field(default=None) detokenization_port: int = field(default=None) diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index 40607c2ce..b1d2805c2 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -139,6 +139,7 @@ async def wait_to_model_ready(self): info_queue=self.info_queue, mem_queue=self.mem_queues[(rank_id % node_world_size)], router_lock=self.router_lock, + mem_queues=self.mem_queues, ) ) tasks.append(task) diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index ab2965887..f774fa3d9 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -361,6 
+361,7 @@ def _init_all_state(self): self.shm_req = g_infer_context.shm_req_manager.get_req_obj_by_index(self.shm_index) self.shm_req.link_prompt_ids_shm_array() self.shm_req.link_logprobs_shm_array() + self.shm_req.link_kv_indexes_shm_array() self.sampling_param: InferSamplingParams = InferSamplingParams(self.shm_req, self.vocab_size) # 更新 nixl pd 分离模式下, prefill 节点需要开始传输的起始位置 diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py index 633102571..d1aceab32 100644 --- a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py @@ -1,11 +1,14 @@ import torch import time import numpy as np +import os import torch.nn.functional as F +import torch.distributed as dist from typing import List, Tuple, Optional, Callable +from lightllm.common.kv_trans_kernel.kv_trans_v2 import kv_trans_for_dp from lightllm.server.router.model_infer.mode_backend.base_backend import ModeBackend from lightllm.common.basemodel.batch_objs import ModelOutput, ModelInput -from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq +from lightllm.server.router.model_infer.infer_batch import InferSamplingParams, g_infer_context, InferReq from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample from lightllm.server.router.model_infer.mode_backend.pre import ( padded_prepare_prefill_inputs, @@ -23,14 +26,20 @@ from lightllm.server.router.model_infer.pin_mem_manager import g_pin_mem_manager from lightllm.common.basemodel.triton_kernel.mtp_utils import mtp_scatter_next_token_ids from .control_state import DPControlState +from lightllm.common.mem_manager import MemoryManager +import torch.multiprocessing as mp + +min_trans_token_num = os.getenv("MIN_TRANS_TOKEN_NUM", 0) class DPChunkedPrefillBackend(ModeBackend): - def __init__(self) -> None: + def __init__(self, mem_queues=None) -> None: super().__init__() # 用于控制每一步是执行prefill 和 decode 还是跳过 self.control_state_machine = DPControlState(backend=self) + self.disable_dp_prompt_cache_fetch = get_env_start_args().disable_dp_prompt_cache_fetch + self.min_trans_token_num = min_trans_token_num # 在 mtp 模式下切换绑定的prefill 和 decode 函数 if get_env_start_args().mtp_mode: @@ -60,8 +69,149 @@ def __init__(self) -> None: self.decode = self.decode_normal self.classed_req_strict_prefill = False + if not self.disable_dp_prompt_cache_fetch and self.dp_size_in_node > 1 and mem_queues is not None: + self._init_dp_cache_fetch(mem_queues) return + def _init_dp_cache_fetch(self, mem_queues): + from .p2p_fix import reduce_tensor + + mp.reductions.reduce_tensor.__code__ = reduce_tensor.__code__ + self.mem_managers: List[MemoryManager] = [mem_queue.get(timeout=60) for mem_queue in mem_queues] + + # 一些可以复用的通用功能函数 + def _init_reqs(self, reqs: List[Tuple]): + my_reqs = reqs + other_reqs = [] + if self.dp_size_in_node != 1: + dp_rank_in_node = self.dp_rank_in_node + my_reqs = [req for req in reqs if req[3] == dp_rank_in_node] + other_reqs = [req for req in reqs if req[3] != dp_rank_in_node] + + g_infer_state_lock.acquire() + infer_reqs = g_infer_context.add_reqs(my_reqs, init_prefix_cache=False) + if self.dp_size_in_node != 1 and not self.disable_dp_prompt_cache_fetch: + self._post_init_reqs(infer_reqs, other_reqs=other_reqs) + + g_infer_state_lock.release() + req_ids = [e[0] for e in my_reqs] + return req_ids + + def _match_radix_cache(self, shm_req): + input_token_ids = 
shm_req.shm_prompt_ids.arr[0 : shm_req.input_len] + key = torch.tensor(input_token_ids, dtype=torch.int64, device="cpu") + key = key[0 : len(key) - 1] # 最后一个不需要,因为需要一个额外的token,让其在prefill的时候输出下一个token的值 + share_node, kv_len, value_tensor = g_infer_context.radix_cache.match_prefix(key, update_refs=True) + return share_node, kv_len, value_tensor + + def _post_init_reqs(self, infer_reqs: List[InferReq], other_reqs: List[Tuple] = []): + my_match = [] + other_match = [] + # match all the reqs in this dp rank. + for r in infer_reqs: + if r.sampling_param.disable_prompt_cache: + continue + shm_req = r.shm_req + + _, kv_len, value_tensor = self._match_radix_cache(shm_req) + # only the first rank is ok + if self.rank_in_dp == 0: + with g_infer_context.shm_req_manager.get_req_lock_by_index(shm_req.index_in_shm_mem): + if kv_len > shm_req.dp_max_kv_len: + shm_req.dp_max_kv_len = kv_len + shm_req.dp_max_kv_rank = self.dp_rank_in_node # 单机 + my_match.append(shm_req, kv_len, value_tensor) + + # match all the reqs in other dp ranks. + if self.rank_in_dp == 0: + for r in other_reqs: + _, shm_index, _, _ = r + shm_req = g_infer_context.shm_req_manager.get_req_obj_by_index(shm_index) + sampling_param = InferSamplingParams(shm_req, g_infer_context.vocab_size) + if sampling_param.disable_prompt_cache: + continue + shm_req.link_prompt_ids_shm_array() + shm_req.link_kv_indexes_shm_array() + + _, kv_len, value_tensor = self._match_radix_cache(shm_req) + with g_infer_context.shm_req_manager.get_req_lock_by_index(shm_req.index_in_shm_mem): + if kv_len > shm_req.dp_max_kv_len: + shm_req.dp_max_kv_len = kv_len + shm_req.dp_max_kv_rank = self.dp_rank_in_node # 单机 + other_match.append(shm_req, kv_len, value_tensor) + + # wait all the ranks to finish the match + dist.barrier() + + # Copy the kv_indexes of this dp rank to other required req + for match in other_match: + shm_req, kv_len, value_tensor = match + if shm_req.dp_max_kv_rank == self.dp_rank_in_node: + shm_req.shm_kv_indexes.arr[0:kv_len] = value_tensor + # release other dp_rank's shm_req + self.release_all_shm_reqs([match[0] for match in other_match]) + + # wait all the ranks to finish the copy + dist.barrier() + + # Perform a kv transfer, get all indexes and the corresponding dp_rank + move_token_indexes = [] + token_dp_indexes = [] + alloc_size = 0 + for r in my_match: + shm_req, kv_len, value_tensor = r + trans_len = shm_req.dp_max_kv_len - kv_len + if trans_len > 0 and shm_req.dp_max_kv_rank != self.dp_rank_in_node: + # Only copy kv_indexes that are not in this dp rank. 
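                # move_token_indexes collects the KV slot indexes inside the owning rank's
                # mem_manager (published by that rank through shm_kv_indexes above), and
                # token_dp_indexes records which dp rank owns each slot; both feed the
                # kv_trans_for_dp gather kernel invoked via copy_kv_from_other_dp_ranks below.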
+ move_token_indexes.extend(shm_req.shm_kv_indexes.arr[kv_len : shm_req.dp_max_kv_len]) + token_dp_indexes.extend([shm_req.dp_max_kv_rank for _ in range(trans_len)]) + alloc_size += trans_len + + if alloc_size < self.min_trans_token_num: + return + + # Exit if alloc fails + try: + mem_indexes = self.model.mem_manager.alloc(alloc_size).cuda() + except Exception as e: + self.logger.error(f"error alloc mem manager: {str(e)}") + return + + move_token_indexes = torch.tensor(move_token_indexes, dtype=torch.int64, device="cuda") + token_dp_indexes = torch.tensor(token_dp_indexes, dtype=torch.int32, device="cuda") + + # transfer kv + self.model.mem_manager.copy_kv_from_other_dp_ranks( + mem_managers=self.mem_managers, + move_token_indexes=move_token_indexes, + token_dp_indexes=token_dp_indexes, + mem_indexes=mem_indexes, + dp_size_in_node=self.dp_size_in_node, + rank_in_dp=self.rank_in_dp, + ) + + # 更新radix cache + start_index = 0 + for r in my_match: + shm_req, kv_len, value_tensor = r + trans_len = shm_req.dp_max_kv_len - kv_len + if trans_len > 0 and shm_req.dp_max_kv_rank != self.dp_rank_in_node: + new_value_tensor = mem_indexes[start_index : start_index + trans_len].cpu() + start_index += trans_len + if value_tensor is not None: + value_tensor = torch.cat((value_tensor, new_value_tensor), dim=0) + else: + value_tensor = new_value_tensor + g_infer_context.radix_cache.insert(shm_req.shm_prompt_ids.arr[0 : shm_req.dp_max_kv_len], value_tensor) + + # 更新infer_req的状态 + for r in infer_reqs: + r._match_radix_cache() + + def release_all_shm_reqs(self, shm_reqs): + for shm_req in shm_reqs: + g_infer_context.shm_req_manager.put_back_req_obj(shm_req) + def infer_loop(self): torch.cuda.set_device(get_current_device_id()) try: diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/p2p_fix.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/p2p_fix.py new file mode 100644 index 000000000..62609c4c9 --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/p2p_fix.py @@ -0,0 +1,155 @@ +# mypy: allow-untyped-defs +import multiprocessing +import os +import threading +from multiprocessing.reduction import ForkingPickler +from multiprocessing.util import register_after_fork +from typing import Union + +import torch +import torch.utils.hooks +from torch._namedtensor_internals import check_serializing_named_tensor +from torch.multiprocessing.reductions import storage_from_cache, shared_cache, StorageWeakRef +from torch.multiprocessing.reductions import reduce_nested_tensor, reduce_sparse_tensor, rebuild_tensor + + +def p2p_fix_rebuild_cuda_tensor( + tensor_cls, + tensor_size, + tensor_stride, + tensor_offset, + storage_cls, + dtype, + storage_device, + storage_handle, + storage_size_bytes, + storage_offset_bytes, + requires_grad, + ref_counter_handle, + ref_counter_offset, + event_handle, + event_sync_required, +): + # 因为接收进程在将 tensor 对应的 handle重新转化为指针的时候 + # 在其c++源码中会将当前显卡切换到storage_device再做操作,这样 + # 得到的指针可能不是接收进程当前上下文设备可以访问的,所以在这里 + # hack 修改了使用的 storage_device,这样后续tritonkernel同时 + # 访问几张显卡上的数据,进行p2p操作就不会出问题了。 + storage_device = torch.cuda.current_device() + # If storage_handle is None, storage points to nullptr. 
+ if storage_handle is None or storage_size_bytes == 0: + storage = storage_cls(0, dtype=dtype, device=storage_device, _internal=True) + else: + storage = storage_from_cache(storage_cls, (storage_handle, storage_offset_bytes)) + if storage is None: + torch.cuda._lazy_init() + storage = storage_cls._new_shared_cuda( + storage_device, + storage_handle, + storage_size_bytes, + storage_offset_bytes, + ref_counter_handle, + ref_counter_offset, + event_handle, + event_sync_required, + ) + shared_cache[(storage_handle, storage_offset_bytes)] = StorageWeakRef(storage) + else: + # We already ref counting this Storage, but producer needs new ref-counters to be released. + storage_cls._release_ipc_counter(ref_counter_handle, ref_counter_offset, device=storage_device) + + _storage = storage if isinstance(storage, torch.UntypedStorage) else storage._untyped_storage + + t = torch._utils._rebuild_tensor( + torch.storage.TypedStorage(wrap_storage=_storage, dtype=dtype, _internal=True), + tensor_offset, + tensor_size, + tensor_stride, + ) + + if tensor_cls == torch.nn.parameter.Parameter: + # It is crucial for integer tensors to receive + # the requires_grad=False as an argument in the constructor + t = torch.nn.parameter.Parameter(t, requires_grad=requires_grad) + else: + t.requires_grad = requires_grad + + return t + + +def reduce_tensor(tensor): + if tensor.requires_grad and not tensor.is_leaf: + raise RuntimeError( + "Cowardly refusing to serialize non-leaf tensor which requires_grad, " + "since autograd does not support crossing process boundaries. " + "If you just want to transfer the data, call detach() on the tensor " + "before serializing (e.g., putting it on the queue)." + ) + + check_serializing_named_tensor(tensor) + torch.utils.hooks.warn_if_has_hooks(tensor) + + from torch.nested._internal.nested_tensor import NestedTensor + + if tensor.is_nested and not isinstance(tensor, NestedTensor): + return reduce_nested_tensor(tensor) + + if tensor.layout in { + torch.sparse_coo, + torch.sparse_csr, + torch.sparse_bsr, + torch.sparse_csc, + torch.sparse_bsc, + }: + return reduce_sparse_tensor(tensor) + + storage = tensor._typed_storage() + + if storage._untyped_storage.device.type == "cuda": + ( + device, + handle, + storage_size_bytes, + storage_offset_bytes, + ref_counter_handle, + ref_counter_offset, + event_handle, + event_sync_required, + ) = storage._share_cuda_() + tensor_offset = tensor.storage_offset() + shared_cache[handle] = StorageWeakRef(storage) + # _backward_hooks purposely omitted here, see + # Note [Don't serialize hooks] + from lightllm.server.router.model_infer.mode_backend.continues_batch.pd_mode.p2p_fix import ( + p2p_fix_rebuild_cuda_tensor, + ) + + return ( + p2p_fix_rebuild_cuda_tensor, + ( + type(tensor), + tensor.size(), + tensor.stride(), + tensor_offset, # tensor offset in its storage + type(storage), + tensor.dtype, + device, + handle, # identifier which CUDA allocation is the storage in. 
+ storage_size_bytes, # size(in bytes) of the storage + storage_offset_bytes, # offset(in bytes) of the storage in the CUDA allocation + tensor.requires_grad, + ref_counter_handle, + ref_counter_offset, + event_handle, + event_sync_required, + ), + ) + + # _backward_hooks purposely omitted here, see Note [Don't serialize hooks] + metadata = ( + tensor.storage_offset(), + tensor.size(), + tensor.stride(), + tensor.requires_grad, + ) + return (rebuild_tensor, (type(tensor), storage, metadata)) diff --git a/lightllm/server/router/model_infer/model_rpc.py b/lightllm/server/router/model_infer/model_rpc.py index 1bb625db0..a803d7ca5 100644 --- a/lightllm/server/router/model_infer/model_rpc.py +++ b/lightllm/server/router/model_infer/model_rpc.py @@ -48,12 +48,14 @@ def __init__( rpc_finished_event: multiprocessing.Event, info_queue: mp.Queue, mem_queue: mp.Queue, + mem_queues: List[mp.Queue] = None, ): super().__init__() self.args: StartArgs = args self.node_world_size = node_world_size self.info_queue = info_queue self.mem_queue = mem_queue + self.mem_queues = mem_queues self.rpc_event = rpc_event self.rpc_finished_event = rpc_finished_event @@ -149,7 +151,7 @@ def init_model(self, kvargs): self.backend = NIXLDecodeNode(self.info_queue, self.mem_queue) elif self.args.dp > 1: - self.backend = DPChunkedPrefillBackend() + self.backend = DPChunkedPrefillBackend(mem_queues=self.mem_queues) elif use_reward_model: self.backend = RewardModelBackend() elif return_all_prompt_logprobs: @@ -223,6 +225,7 @@ def _init_env( rpc_event: mp.Event, rpc_finished_event: mp.Event, success_event: mp.Event, + mem_queues: List[mp.Queue] = None, ): import lightllm.utils.rpyc_fix_utils as _ @@ -237,7 +240,7 @@ def _init_env( g_router_lock.obj = router_lock model_rpc_server = ModelRpcServer( - args, rank, rank_in_node, node_world_size, rpc_event, rpc_finished_event, info_queue, mem_queue + args, rank, rank_in_node, node_world_size, rpc_event, rpc_finished_event, info_queue, mem_queue, mem_queues ) success_event.set() @@ -255,6 +258,7 @@ async def start_model_process( info_queue: mp.Queue, mem_queue: mp.Queue, router_lock: mp.Queue, + mem_queues: List[mp.Queue] = None, ): import lightllm.utils.rpyc_fix_utils as _ @@ -272,6 +276,7 @@ async def start_model_process( rpc_event, rpc_finished_event, success_event, + mem_queues, ), ) proc.start() From 962ef2be89bb77ec3be87b1625e36875a2978445 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 9 Dec 2025 05:42:22 +0000 Subject: [PATCH 02/34] fix --- .../kv_cache_mem_manager/mem_manager.py | 7 +-- lightllm/server/api_cli.py | 3 +- lightllm/server/api_start.py | 9 ++++ .../decode_node_impl/decode_impl_for_dp.py | 3 +- .../prefill_node_impl/prefill_impl_for_dp.py | 3 +- .../mode_backend/dp_backend/impl.py | 47 +++++++++++++------ .../decode_node_impl/decode_impl_for_dp.py | 3 +- .../prefill_node_impl/prefill_impl_for_dp.py | 3 +- .../server/router/model_infer/model_rpc.py | 2 +- 9 files changed, 50 insertions(+), 30 deletions(-) diff --git a/lightllm/common/kv_cache_mem_manager/mem_manager.py b/lightllm/common/kv_cache_mem_manager/mem_manager.py index 9f94a2edd..3ffac99a5 100755 --- a/lightllm/common/kv_cache_mem_manager/mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/mem_manager.py @@ -420,14 +420,9 @@ def copy_kv_from_other_dp_ranks( mems_ptr = torch.tensor(mems_ptr, dtype=torch.uint64, device="cuda") self.mem_ptrs_dict[layer_index] = mems_ptr - input_mems = [] - for i in range(len(self.mem_managers)): - input_mems.append(self.mem_managers[i].kv_buffer.data_ptr()) - 
input_mems = torch.tensor(input_mems, dtype=torch.uint64, device="cuda") - for layer_index in range(self.layer_num): kv_trans_for_dp( - input_mems=input_mems[layer_index], + input_mems=self.mem_ptrs_dict[layer_index], input_idx=move_token_indexes, input_dp_idx=token_dp_indexes, output=self.kv_buffer[layer_index], diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index b5dc60749..932d4567f 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -570,6 +570,7 @@ def make_argument_parser() -> argparse.ArgumentParser: "--disable_dp_prompt_cache_fetch", action="store_true", default=False, - help="""Disable prefix prompt cache fetch for data parallel inference. Enabled by default""", + help="""Disable prefix prompt cache fetch for data parallel inference. + Enabled by default, but currently not supported for pd separated mode""", ) return parser diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index f73be30db..901ba7d4f 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -269,6 +269,15 @@ def normal_or_p_d_start(args): args.router_max_wait_tokens = 0 send_and_receive_node_ip(args) # 多机用于收发node ip + # PD 分离模式下必须禁用 DP prompt cache fetch,且 dp 必须 > 1 + if not args.disable_dp_prompt_cache_fetch: + if args.run_mode != "normal" or args.dp <= 1: + args.disable_dp_prompt_cache_fetch = True + logger.warning( + """PD split mode or dp <= 1 does not support dp_prompt_cache_fetch; + overriding disable_dp_prompt_cache_fetch to True""" + ) + set_env_start_args(args) logger.info(f"all start args:{args}") diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl_for_dp.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl_for_dp.py index 07288508a..2d637c817 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl_for_dp.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl_for_dp.py @@ -10,9 +10,8 @@ class DPForDecodeNode(DPChunkedPrefillBackend): def __init__(self, info_queue: mp.Queue, mem_queue: mp.Queue) -> None: - super().__init__() + super().__init__(mem_queue=mem_queue) self.info_queue: mp.Queue = info_queue - self.mem_queue: mp.Queue = mem_queue self.classed_req_strict_prefill = False return diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl_for_dp.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl_for_dp.py index f7f338c78..21999f6c7 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl_for_dp.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl_for_dp.py @@ -10,10 +10,9 @@ class DPChunkedForPrefillNode(DPChunkedPrefillBackend): def __init__(self, info_queue: mp.Queue, mem_queue: mp.Queue) -> None: - super().__init__() + super().__init__(mem_queue=mem_queue) self.support_overlap = False self.info_queue: mp.Queue = info_queue - self.mem_queue: mp.Queue = mem_queue self.classed_req_no_decode = True def init_custom(self): diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py index d1aceab32..f39f91e5b 100644 --- 
a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py @@ -29,11 +29,11 @@ from lightllm.common.mem_manager import MemoryManager import torch.multiprocessing as mp -min_trans_token_num = os.getenv("MIN_TRANS_TOKEN_NUM", 0) +min_trans_token_num = os.getenv("MIN_TRANS_TOKEN_NUM", 1) class DPChunkedPrefillBackend(ModeBackend): - def __init__(self, mem_queues=None) -> None: + def __init__(self, mem_queue: mp.Queue, mem_queues: List[mp.Queue] = None) -> None: super().__init__() # 用于控制每一步是执行prefill 和 decode 还是跳过 @@ -71,13 +71,31 @@ def __init__(self, mem_queues=None) -> None: self.classed_req_strict_prefill = False if not self.disable_dp_prompt_cache_fetch and self.dp_size_in_node > 1 and mem_queues is not None: self._init_dp_cache_fetch(mem_queues) + self.mem_queues = mem_queues + self.mem_queue = mem_queue return - def _init_dp_cache_fetch(self, mem_queues): - from .p2p_fix import reduce_tensor + def init_custom(self): + if not self.disable_dp_prompt_cache_fetch and self.dp_size_in_node > 1 and self.mem_queues is not None: + torch.cuda.set_device(get_current_device_id()) - mp.reductions.reduce_tensor.__code__ = reduce_tensor.__code__ - self.mem_managers: List[MemoryManager] = [mem_queue.get(timeout=60) for mem_queue in mem_queues] + from .p2p_fix import reduce_tensor + + mp.reductions.reduce_tensor.__code__ = reduce_tensor.__code__ + + # 必须先设置 CUDA 设备上下文,因为序列化和反序列化 CUDA tensor 时 + # reduce_tensor 和 p2p_fix_rebuild_cuda_tensor 都会调用 torch.cuda.current_device() + + for _ in range(self.node_world_size - 1): + self.mem_queue.put(self.model.mem_manager) + + self.mem_managers = [] + for queue_index in range(len(self.mem_queues)): + if queue_index != self.rank_in_node: + self.mem_managers.append(self.mem_queues[queue_index].get(timeout=60)) + else: + self.mem_managers.append(self.model.mem_manager) + return # 一些可以复用的通用功能函数 def _init_reqs(self, reqs: List[Tuple]): @@ -92,6 +110,8 @@ def _init_reqs(self, reqs: List[Tuple]): infer_reqs = g_infer_context.add_reqs(my_reqs, init_prefix_cache=False) if self.dp_size_in_node != 1 and not self.disable_dp_prompt_cache_fetch: self._post_init_reqs(infer_reqs, other_reqs=other_reqs) + for r in infer_reqs: + r._match_radix_cache() g_infer_state_lock.release() req_ids = [e[0] for e in my_reqs] @@ -120,7 +140,7 @@ def _post_init_reqs(self, infer_reqs: List[InferReq], other_reqs: List[Tuple] = if kv_len > shm_req.dp_max_kv_len: shm_req.dp_max_kv_len = kv_len shm_req.dp_max_kv_rank = self.dp_rank_in_node # 单机 - my_match.append(shm_req, kv_len, value_tensor) + my_match.append((shm_req, kv_len, value_tensor)) # match all the reqs in other dp ranks. 
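        # Only rank 0 inside each dp group walks the requests routed to other dp ranks and
        # updates their dp_max_kv_len / dp_max_kv_rank, so the tensor-parallel peers of the
        # same group neither repeat the scan nor race on the shared shm_req fields.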
if self.rank_in_dp == 0: @@ -138,7 +158,7 @@ def _post_init_reqs(self, infer_reqs: List[InferReq], other_reqs: List[Tuple] = if kv_len > shm_req.dp_max_kv_len: shm_req.dp_max_kv_len = kv_len shm_req.dp_max_kv_rank = self.dp_rank_in_node # 单机 - other_match.append(shm_req, kv_len, value_tensor) + other_match.append((shm_req, kv_len, value_tensor)) # wait all the ranks to finish the match dist.barrier() @@ -174,7 +194,7 @@ def _post_init_reqs(self, infer_reqs: List[InferReq], other_reqs: List[Tuple] = try: mem_indexes = self.model.mem_manager.alloc(alloc_size).cuda() except Exception as e: - self.logger.error(f"error alloc mem manager: {str(e)}") + self.logger.error(f"dp_i {self.dp_rank_in_node} error alloc mem manager: {str(e)}") return move_token_indexes = torch.tensor(move_token_indexes, dtype=torch.int64, device="cuda") @@ -198,15 +218,14 @@ def _post_init_reqs(self, infer_reqs: List[InferReq], other_reqs: List[Tuple] = if trans_len > 0 and shm_req.dp_max_kv_rank != self.dp_rank_in_node: new_value_tensor = mem_indexes[start_index : start_index + trans_len].cpu() start_index += trans_len + key = torch.tensor( + shm_req.shm_prompt_ids.arr[0 : shm_req.dp_max_kv_len], dtype=torch.int64, device="cpu" + ) if value_tensor is not None: value_tensor = torch.cat((value_tensor, new_value_tensor), dim=0) else: value_tensor = new_value_tensor - g_infer_context.radix_cache.insert(shm_req.shm_prompt_ids.arr[0 : shm_req.dp_max_kv_len], value_tensor) - - # 更新infer_req的状态 - for r in infer_reqs: - r._match_radix_cache() + g_infer_context.radix_cache.insert(key, value_tensor) def release_all_shm_reqs(self, shm_reqs): for shm_req in shm_reqs: diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_impl_for_dp.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_impl_for_dp.py index 878c76e0f..758ebf50b 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_impl_for_dp.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_impl_for_dp.py @@ -10,9 +10,8 @@ class NIXLDPForDecodeNode(DPChunkedPrefillBackend): def __init__(self, info_queue: mp.Queue, mem_queue: mp.Queue) -> None: - super().__init__() + super().__init__(mem_queue=mem_queue) self.info_queue: mp.Queue = info_queue - self.mem_queue: mp.Queue = mem_queue self.classed_req_strict_prefill = False return diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_impl_for_dp.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_impl_for_dp.py index b76eb9124..1aeddd267 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_impl_for_dp.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_impl_for_dp.py @@ -10,10 +10,9 @@ class NIXLDPChunkedForPrefillNode(DPChunkedPrefillBackend): def __init__(self, info_queue: mp.Queue, mem_queue: mp.Queue) -> None: - super().__init__() + super().__init__(mem_queue=mem_queue) self.support_overlap = False self.info_queue: mp.Queue = info_queue - self.mem_queue: mp.Queue = mem_queue self.classed_req_no_decode = True self.nixl_prefill_chuncked_handle_func = self._prefill_chuncked_handle_func diff --git a/lightllm/server/router/model_infer/model_rpc.py b/lightllm/server/router/model_infer/model_rpc.py index a803d7ca5..37e8a9454 100644 --- a/lightllm/server/router/model_infer/model_rpc.py +++ b/lightllm/server/router/model_infer/model_rpc.py @@ 
-151,7 +151,7 @@ def init_model(self, kvargs): self.backend = NIXLDecodeNode(self.info_queue, self.mem_queue) elif self.args.dp > 1: - self.backend = DPChunkedPrefillBackend(mem_queues=self.mem_queues) + self.backend = DPChunkedPrefillBackend(mem_queue=self.mem_queue, mem_queues=self.mem_queues) elif use_reward_model: self.backend = RewardModelBackend() elif return_all_prompt_logprobs: From 0cd2b86f83e45ab6b7a5c2017db88dc23f46d872 Mon Sep 17 00:00:00 2001 From: wanzihao <1060304770@qq.com> Date: Mon, 3 Nov 2025 07:21:29 +0000 Subject: [PATCH 03/34] add enable_dp_prompt_cache_fetch --- .../router/model_infer/mode_backend/dp_backend/impl.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py index f39f91e5b..0af16e370 100644 --- a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py @@ -76,7 +76,10 @@ def __init__(self, mem_queue: mp.Queue, mem_queues: List[mp.Queue] = None) -> No return def init_custom(self): - if not self.disable_dp_prompt_cache_fetch and self.dp_size_in_node > 1 and self.mem_queues is not None: + self.enable_dp_prompt_cache_fetch = ( + not self.disable_dp_prompt_cache_fetch and self.dp_size_in_node > 1 and self.mem_queues is not None + ) + if self.enable_dp_prompt_cache_fetch: torch.cuda.set_device(get_current_device_id()) from .p2p_fix import reduce_tensor @@ -107,8 +110,8 @@ def _init_reqs(self, reqs: List[Tuple]): other_reqs = [req for req in reqs if req[3] != dp_rank_in_node] g_infer_state_lock.acquire() - infer_reqs = g_infer_context.add_reqs(my_reqs, init_prefix_cache=False) - if self.dp_size_in_node != 1 and not self.disable_dp_prompt_cache_fetch: + infer_reqs = g_infer_context.add_reqs(my_reqs, init_prefix_cache=not self.enable_dp_prompt_cache_fetch) + if self.enable_dp_prompt_cache_fetch: self._post_init_reqs(infer_reqs, other_reqs=other_reqs) for r in infer_reqs: r._match_radix_cache() @@ -209,6 +212,7 @@ def _post_init_reqs(self, infer_reqs: List[InferReq], other_reqs: List[Tuple] = dp_size_in_node=self.dp_size_in_node, rank_in_dp=self.rank_in_dp, ) + self.logger.info(f"dp_i {self.dp_rank_in_node} transfer kv tokens num: {alloc_size}") # 更新radix cache start_index = 0 From 0667d71cd210cc1d8680ce50d415a574ac26d4eb Mon Sep 17 00:00:00 2001 From: wanzihao <1060304770@qq.com> Date: Tue, 4 Nov 2025 04:40:21 +0000 Subject: [PATCH 04/34] free_radix_cache_to_get_enough_token instead of skip --- .../router/model_infer/mode_backend/dp_backend/impl.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py index 0af16e370..bbb076e63 100644 --- a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py @@ -29,7 +29,7 @@ from lightllm.common.mem_manager import MemoryManager import torch.multiprocessing as mp -min_trans_token_num = os.getenv("MIN_TRANS_TOKEN_NUM", 1) +min_trans_token_num = os.getenv("MIN_TRANS_TOKEN_NUM", 128) class DPChunkedPrefillBackend(ModeBackend): @@ -193,12 +193,8 @@ def _post_init_reqs(self, infer_reqs: List[InferReq], other_reqs: List[Tuple] = if alloc_size < self.min_trans_token_num: return - # Exit if alloc fails - try: - mem_indexes = 
self.model.mem_manager.alloc(alloc_size).cuda() - except Exception as e: - self.logger.error(f"dp_i {self.dp_rank_in_node} error alloc mem manager: {str(e)}") - return + g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(alloc_size) + mem_indexes = self.model.mem_manager.alloc(alloc_size).cuda() move_token_indexes = torch.tensor(move_token_indexes, dtype=torch.int64, device="cuda") token_dp_indexes = torch.tensor(token_dp_indexes, dtype=torch.int32, device="cuda") From 5a1e22d38efdc36fe753b71f5ff24b05b47fa868 Mon Sep 17 00:00:00 2001 From: wanzihao <1060304770@qq.com> Date: Tue, 4 Nov 2025 05:19:09 +0000 Subject: [PATCH 05/34] add use_openai_api, port, concurrency, history-turns, max-total-tokens to benchmark_sharegpt --- test/benchmark/service/benchmark_sharegpt.py | 346 ++++++++++++++----- 1 file changed, 265 insertions(+), 81 deletions(-) diff --git a/test/benchmark/service/benchmark_sharegpt.py b/test/benchmark/service/benchmark_sharegpt.py index c9f92f098..d6f26ff0b 100644 --- a/test/benchmark/service/benchmark_sharegpt.py +++ b/test/benchmark/service/benchmark_sharegpt.py @@ -26,6 +26,7 @@ import aiohttp import numpy as np from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase +from tqdm.asyncio import tqdm from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast @@ -63,112 +64,267 @@ def sample_requests( dataset_path: str, num_requests: int, tokenizer: PreTrainedTokenizerBase, -) -> List[Tuple[str, int, int]]: + max_history_turns: int = 6, + max_total_tokens: int = 16384, +) -> List[Tuple[List[dict], str, int, int]]: # Load the dataset. with open(dataset_path) as f: dataset = json.load(f) - # Filter out the conversations with less than 2 turns. - dataset = [data for data in dataset if len(data["conversations"]) >= 2] - # Only keep the first two turns of each conversation. - dataset = [(data["conversations"][0]["value"], data["conversations"][1]["value"]) for data in dataset] + # Filter out the conversations with at least 2 turns. 
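    # i.e. keep only conversations with at least `max_history_turns` turns, so that long
    # multi-turn contexts ending in an assistant reply can be built further below.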
+ dataset = [data for data in dataset if len(data["conversations"]) >= max_history_turns] print("read data set finish") + dataset = dataset[: num_requests * 3] + + def to_openai_role(role_value: str) -> str: + lower_value = role_value.lower() + if lower_value in ["human", "user", "system"]: + return "user" if lower_value != "system" else "system" + return "assistant" + + # Build messages and targets + built_examples: List[Tuple[List[dict], str]] = [] + for data in dataset: + convs = data.get("conversations", []) + if not convs: + continue + # Find the last assistant turn to be used as the completion target + last_assistant_idx = -1 + for idx in range(len(convs) - 1, -1, -1): + role_val = convs[idx].get("from") or convs[idx].get("role") or "assistant" + if to_openai_role(role_val) == "assistant": + last_assistant_idx = idx + break + if last_assistant_idx <= 0: + # Need at least one prompt message before the assistant response + continue + # Determine how many turns of history to keep before the target assistant turn + start_idx = max(0, last_assistant_idx - max_history_turns) + context_convs = convs[start_idx:last_assistant_idx] + completion_text = convs[last_assistant_idx].get("value") or convs[last_assistant_idx].get("content") or "" + if not completion_text: + continue + messages: List[dict] = [] + for turn in context_convs: + role_val = turn.get("from") or turn.get("role") or "user" + content_val = turn.get("value") or turn.get("content") or "" + if not content_val: + continue + messages.append({"role": to_openai_role(role_val), "content": content_val}) + if not messages: + continue + built_examples.append((messages, completion_text)) + + # Render prompts using chat template when possible + rendered_prompts: List[str] = [] + for messages, _ in built_examples: + rendered_text = None + try: + # Prefer using the tokenizer's chat template + rendered_text = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + ) + except Exception: + # Fallback rendering if chat template is unavailable + parts = [] + for m in messages: + parts.append(f"{m['role']}: {m['content']}") + parts.append("assistant:") + rendered_text = "\n".join(parts) + rendered_prompts.append(rendered_text) + # Tokenize the prompts and completions. - import random - - dataset = random.sample(dataset, num_requests * 3) - prompts = [prompt for prompt, _ in dataset] - completions = [completion for _, completion in dataset] - - prompt_token_ids = tokenizer(prompts).input_ids - completion_token_ids = tokenizer(completions).input_ids - tokenized_dataset = [] - for i in range(len(dataset)): - output_len = len(completion_token_ids[i]) - tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len)) - - # Filter out too long sequences. - filtered_dataset: List[Tuple[str, int, int]] = [] - for prompt, prompt_token_ids, output_len in tokenized_dataset: - prompt_len = len(prompt_token_ids) + prompt_token_ids = tokenizer(rendered_prompts).input_ids if rendered_prompts else [] + completion_texts = [completion for _, completion in built_examples] + completion_token_ids = tokenizer(completion_texts).input_ids if completion_texts else [] + + tokenized_dataset: List[Tuple[List[dict], str, int, int]] = [] + for i in range(len(built_examples)): + messages, _ = built_examples[i] + prompt_len = len(prompt_token_ids[i]) + output_len = min(len(completion_token_ids[i]), 128) + tokenized_dataset.append((messages, rendered_prompts[i], prompt_len, output_len)) + + # Filter out too long or too short sequences. 
+ filtered_dataset: List[Tuple[List[dict], str, int, int]] = [] + for messages, rendered_prompt, prompt_len, output_len in tokenized_dataset: if prompt_len < 4 or output_len < 4: - # Prune too short sequences. continue - if prompt_len > 1024 or prompt_len + output_len > 2048: - # Prune too long sequences. + if (prompt_len + output_len) >= max_total_tokens: continue - filtered_dataset.append((prompt, prompt_len, output_len)) + filtered_dataset.append((messages, rendered_prompt, prompt_len, output_len)) # Sample the requests. - sampled_requests = random.sample(filtered_dataset, num_requests) + sampled_requests = filtered_dataset[:num_requests] sum_len = 0 - for e in sampled_requests: - sum_len += e[1] + e[2] + for _, _, prompt_len, output_len in sampled_requests: + sum_len += prompt_len + output_len print("total tokens:", sum_len) return sampled_requests async def get_request( - input_requests: List[Tuple[str, int, int]], + input_requests: List[Tuple[List[dict], str, int, int]], request_rate: float, -) -> AsyncGenerator[Tuple[str, int, int], None]: + concurrency: int = None, +) -> AsyncGenerator[Tuple[List[dict], str, int, int], None]: input_requests = iter(input_requests) - for request in input_requests: - yield request - if request_rate == float("inf"): - # If the request rate is infinity, then we don't need to wait. - continue - # Sample the request interval from the exponential distribution. - interval = np.random.exponential(1.0 / request_rate) - # The next request will be sent after the interval. - await asyncio.sleep(interval) - - -async def send_request(prompt: str, prompt_len: int, output_len: int) -> None: - request_start_time = time.time() - headers = {"Content-Type": "application/json"} - headers = {"User-Agent": "Benchmark Client"} - url = "http://localhost:8000/generate" - - data = { - "inputs": prompt, - "parameters": { - "do_sample": False, + if concurrency is not None: + # Concurrency-based request generation + # This generator will be consumed by the benchmark function + # which will manage the concurrency + for request in input_requests: + yield request + else: + # Rate-based request generation (original logic) + for request in input_requests: + yield request + + if request_rate == float("inf"): + # If the request rate is infinity, then we don't need to wait. + continue + # Sample the request interval from the exponential distribution. + interval = np.random.exponential(1.0 / request_rate) + # The next request will be sent after the interval. + await asyncio.sleep(interval) + + +async def send_request( + messages: List[dict], + rendered_prompt: str, + prompt_len: int, + output_len: int, + use_openai_api: bool, + port: int, + pbar=None, +) -> None: + if use_openai_api: + # Use OpenAI API to send the request. + # Use local server to send the request. 
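        # This branch streams from the server's OpenAI-style /v1/chat/completions endpoint;
        # time-to-first-token is taken when the first streamed chunk arrives below.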
+ request_start_time = time.time() + headers = {"Content-Type": "application/json", "User-Agent": "Benchmark Client"} + url = f"http://localhost:{port}/v1/chat/completions" + + data = { + "model": "DeepSeek-R1", + "messages": messages, + "top_k": 1, + "top_p": 1.0, + "temperature": 0, + "stream": True, "ignore_eos": True, - "max_new_tokens": output_len, - # 'temperature': 0.1, - }, - } - - timeout = aiohttp.ClientTimeout(total=3 * 3600) - async with aiohttp.ClientSession(timeout=timeout) as session: - while True: + "max_tokens": output_len, + } + timeout = aiohttp.ClientTimeout(total=3 * 3600) + receive_n = 1 + + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.post(url, headers=headers, json=data) as response: + chunks = [] + text = "" + start_time = time.time() + is_first = True + async for chunk, _ in response.content.iter_chunks(): + now_time = time.time() + delta_time = now_time - start_time + if is_first: + is_first = False + ttft = delta_time + text += json.loads(chunk.decode("utf-8")[6:])["choices"][0]["delta"].get("content", "") + if delta_time < 0.005: + receive_n += 1 + chunks.append(delta_time) + start_time = now_time + # print("messages", messages) + # print("text", text) + + else: + # Use local server to send the request. + request_start_time = time.time() + headers = {"Content-Type": "application/json", "User-Agent": "Benchmark Client"} + url = f"http://localhost:{port}/generate_stream" + + data = { + "inputs": rendered_prompt, + "parameters": { + "do_sample": False, + "ignore_eos": True, + "max_new_tokens": output_len, + }, + } + + timeout = aiohttp.ClientTimeout(total=3 * 3600) + async with aiohttp.ClientSession(timeout=timeout) as session: + receive_n = 0 + text = "" async with session.post(url, headers=headers, json=data) as response: chunks = [] + start_time = time.time() + is_first = True async for chunk, _ in response.content.iter_chunks(): + now_time = time.time() + delta_time = now_time - start_time + if is_first: + is_first = False + ttft = delta_time + if delta_time < 0.005: + receive_n += 1 chunks.append(chunk) - output = b"".join(chunks).decode("utf-8") - output = json.loads(output) - - if "error" not in output: - break + text += json.loads(chunk.decode("utf-8")[5:])["token"]["text"] + start_time = now_time request_end_time = time.time() request_latency = request_end_time - request_start_time - REQUEST_LATENCY.append((prompt_len, output_len, request_latency)) + REQUEST_LATENCY.append((prompt_len, output_len, request_latency, ttft)) + + # Update progress bar if provided + if pbar: + pbar.update(1) async def benchmark( - input_requests: List[Tuple[str, int, int]], + input_requests: List[Tuple[List[dict], str, int, int]], request_rate: float, + use_openai_api: bool = False, + concurrency: int = None, + port: int = 8080, ) -> None: - tasks: List[asyncio.Task] = [] - async for request in get_request(input_requests, request_rate): - prompt, prompt_len, output_len = request - task = asyncio.create_task(send_request(prompt, prompt_len, output_len)) - tasks.append(task) - await asyncio.gather(*tasks) + total_requests = len(input_requests) + + # Create progress bar + pbar = tqdm(total=total_requests, desc="Processing requests", unit="req") + + if concurrency is not None: + # Concurrency-based processing + semaphore = asyncio.Semaphore(concurrency) + tasks: List[asyncio.Task] = [] + + async def send_with_semaphore(messages, rendered_prompt, prompt_len, output_len): + async with semaphore: + await send_request(messages, rendered_prompt, 
prompt_len, output_len, use_openai_api, port, pbar) + + async for request in get_request(input_requests, request_rate, concurrency): + messages, rendered_prompt, prompt_len, output_len = request + task = asyncio.create_task(send_with_semaphore(messages, rendered_prompt, prompt_len, output_len)) + tasks.append(task) + + await asyncio.gather(*tasks) + else: + # Rate-based processing (original logic) + tasks: List[asyncio.Task] = [] + async for request in get_request(input_requests, request_rate, concurrency): + messages, rendered_prompt, prompt_len, output_len = request + task = asyncio.create_task( + send_request(messages, rendered_prompt, prompt_len, output_len, use_openai_api, port, pbar) + ) + tasks.append(task) + await asyncio.gather(*tasks) + + # Close progress bar + pbar.close() def main(args: argparse.Namespace): @@ -176,28 +332,41 @@ def main(args: argparse.Namespace): random.seed(args.seed) np.random.seed(args.seed) tokenizer = get_tokenizer(args.tokenizer, "slow") - input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer) + input_requests = sample_requests( + args.dataset, args.num_prompts, tokenizer, args.history_turns, args.max_total_tokens + ) benchmark_start_time = time.time() - asyncio.run(benchmark(input_requests, args.request_rate)) + asyncio.run(benchmark(input_requests, args.request_rate, args.use_openai_api, args.concurrency, args.port)) benchmark_end_time = time.time() benchmark_time = benchmark_end_time - benchmark_start_time print(f"Total time: {benchmark_time:.2f} s") print(f"Throughput: {args.num_prompts / benchmark_time:.2f} requests/s") # Compute the latency statistics. - avg_latency = np.mean([latency for _, _, latency in REQUEST_LATENCY]) + avg_latency = np.mean([latency for _, _, latency, _ in REQUEST_LATENCY]) print(f"Average latency: {avg_latency:.2f} s") - avg_per_token_latency = np.mean( - [latency / (prompt_len + output_len) for prompt_len, output_len, latency in REQUEST_LATENCY] + avg_time_to_first_token = np.mean([ttft for _, _, _, ttft in REQUEST_LATENCY]) + print("Average time to first token: " f"{avg_time_to_first_token:.2f} s") + avg_per_token_latency = ( + np.mean([latency / (prompt_len + output_len) for prompt_len, output_len, latency, _ in REQUEST_LATENCY]) * 1000 ) - print(f"Average latency per token: {avg_per_token_latency:.2f} s") - avg_per_output_token_latency = np.mean([latency / output_len for _, output_len, latency in REQUEST_LATENCY]) - print("Average latency per output token: " f"{avg_per_output_token_latency:.2f} s") + print(f"Average latency per token: {avg_per_token_latency:.1f} ms") + # avg_per_output_token_latency = np.mean([latency / output_len for _, output_len, latency, _ in REQUEST_LATENCY]) + # print("Average latency per output token: " f"{avg_per_output_token_latency:.2f} s") + avg_inter_token_latency = ( + np.mean( + [(latency - ttft) / (output_len - 1) for _, output_len, latency, ttft in REQUEST_LATENCY if output_len > 1] + ) + * 1000 + ) + print(f"Average inter-token latency: {avg_inter_token_latency:.1f} ms") if __name__ == "__main__": parser = argparse.ArgumentParser(description="Benchmark the online serving throughput.") + parser.add_argument("--use_openai_api", default=False, action="store_true", help="Use OpenAI API for requests.") + parser.add_argument("--port", type=int, default=8080, help="Port of the API server.") parser.add_argument("--dataset", type=str, required=True, help="Path to the dataset.") parser.add_argument("--tokenizer", type=str, required=True, help="Name or path of the tokenizer.") 
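    # Example invocation (dataset and tokenizer paths are placeholders; the flags are the
    # ones defined by this parser):
    #   python benchmark_sharegpt.py --dataset ShareGPT.json --tokenizer /path/to/model \
    #       --num-prompts 1000 --concurrency 32 --use_openai_api --port 8080 \
    #       --history-turns 6 --max-total-tokens 16384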
parser.add_argument( @@ -209,7 +378,22 @@ def main(args: argparse.Namespace): "Otherwise, we use Poisson process to synthesize " "the request arrival times.", ) + parser.add_argument( + "--concurrency", + type=int, + default=None, + help="Number of concurrent requests to maintain. " "Cannot be used together with --request-rate.", + ) parser.add_argument("--num-prompts", type=int, default=1000, help="Number of prompts to process.") + parser.add_argument( + "--history-turns", type=int, default=6, help="Max number of context turns before the target assistant reply." + ) + parser.add_argument("--max-total-tokens", type=int, default=16384, help="Max total tokens (input + output).") parser.add_argument("--seed", type=int, default=0) args = parser.parse_args() + + # Validate that only one of request_rate or concurrency is set + if args.concurrency is not None and args.request_rate != float("inf"): + raise ValueError("Cannot set both --request-rate and --concurrency. Please use only one.") + main(args) From 7551fb75b45f0bbf8ff8ec19ab6fa3573dc3b2e9 Mon Sep 17 00:00:00 2001 From: wanzihao <1060304770@qq.com> Date: Tue, 4 Nov 2025 07:40:15 +0000 Subject: [PATCH 06/34] support pd split --- lightllm/server/api_cli.py | 3 +-- lightllm/server/api_start.py | 13 ++++++------- .../prefill_node_impl/prefill_impl_for_dp.py | 1 + .../prefill_node_impl/prefill_impl_for_dp.py | 1 + 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 932d4567f..d33b459c3 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -570,7 +570,6 @@ def make_argument_parser() -> argparse.ArgumentParser: "--disable_dp_prompt_cache_fetch", action="store_true", default=False, - help="""Disable prefix prompt cache fetch for data parallel inference. 
- Enabled by default, but currently not supported for pd separated mode""", + help="""Disable prefix prompt cache fetch for data parallel inference, enabled by default.""", ) return parser diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 901ba7d4f..92cc98f5a 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -270,13 +270,12 @@ def normal_or_p_d_start(args): send_and_receive_node_ip(args) # 多机用于收发node ip # PD 分离模式下必须禁用 DP prompt cache fetch,且 dp 必须 > 1 - if not args.disable_dp_prompt_cache_fetch: - if args.run_mode != "normal" or args.dp <= 1: - args.disable_dp_prompt_cache_fetch = True - logger.warning( - """PD split mode or dp <= 1 does not support dp_prompt_cache_fetch; - overriding disable_dp_prompt_cache_fetch to True""" - ) + if not args.disable_dp_prompt_cache_fetch and args.dp <= 1: + args.disable_dp_prompt_cache_fetch = True + logger.warning( + """dp <= 1 does not support dp_prompt_cache_fetch; + overriding disable_dp_prompt_cache_fetch to True""" + ) set_env_start_args(args) logger.info(f"all start args:{args}") diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl_for_dp.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl_for_dp.py index 21999f6c7..f4e21e109 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl_for_dp.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl_for_dp.py @@ -17,6 +17,7 @@ def __init__(self, info_queue: mp.Queue, mem_queue: mp.Queue) -> None: def init_custom(self): ChunckedPrefillForPrefillNode.init_custom(self) + super().init_custom() return def _pre_handle_finished_reqs(self, finished_reqs): diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_impl_for_dp.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_impl_for_dp.py index 1aeddd267..02ba5b09d 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_impl_for_dp.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_impl_for_dp.py @@ -18,6 +18,7 @@ def __init__(self, info_queue: mp.Queue, mem_queue: mp.Queue) -> None: def init_custom(self): NIXLChunckedPrefillForPrefillNode.init_custom(self) + super().init_custom() return def _filter_not_ready_reqs(self, req_ids: List[int]) -> List[InferReq]: From a3c44be25811e32658690bea9d9fac5c4a8f650e Mon Sep 17 00:00:00 2001 From: wanzihao <1060304770@qq.com> Date: Tue, 4 Nov 2025 10:35:31 +0000 Subject: [PATCH 07/34] add mem_queues --- .../pd_mode/prefill_node_impl/prefill_impl_for_dp.py | 4 ++-- .../pd_nixl/prefill_node_impl/prefill_impl_for_dp.py | 4 ++-- lightllm/server/router/model_infer/model_rpc.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl_for_dp.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl_for_dp.py index f4e21e109..38d0f014c 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl_for_dp.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl_for_dp.py @@ -9,8 +9,8 @@ class 
DPChunkedForPrefillNode(DPChunkedPrefillBackend): - def __init__(self, info_queue: mp.Queue, mem_queue: mp.Queue) -> None: - super().__init__(mem_queue=mem_queue) + def __init__(self, info_queue: mp.Queue, mem_queue: mp.Queue, mem_queues: List[mp.Queue]) -> None: + super().__init__(mem_queue=mem_queue, mem_queues=mem_queues) self.support_overlap = False self.info_queue: mp.Queue = info_queue self.classed_req_no_decode = True diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_impl_for_dp.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_impl_for_dp.py index 02ba5b09d..b9a7ec330 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_impl_for_dp.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_impl_for_dp.py @@ -9,8 +9,8 @@ class NIXLDPChunkedForPrefillNode(DPChunkedPrefillBackend): - def __init__(self, info_queue: mp.Queue, mem_queue: mp.Queue) -> None: - super().__init__(mem_queue=mem_queue) + def __init__(self, info_queue: mp.Queue, mem_queue: mp.Queue, mem_queues: List[mp.Queue]) -> None: + super().__init__(mem_queue=mem_queue, mem_queues=mem_queues) self.support_overlap = False self.info_queue: mp.Queue = info_queue self.classed_req_no_decode = True diff --git a/lightllm/server/router/model_infer/model_rpc.py b/lightllm/server/router/model_infer/model_rpc.py index 37e8a9454..134a8e9ad 100644 --- a/lightllm/server/router/model_infer/model_rpc.py +++ b/lightllm/server/router/model_infer/model_rpc.py @@ -129,12 +129,12 @@ def init_model(self, kvargs): if is_prefill_node: if self.args.dp > 1: - self.backend = DPChunkedForPrefillNode(self.info_queue, self.mem_queue) + self.backend = DPChunkedForPrefillNode(self.info_queue, self.mem_queue, self.mem_queues) else: self.backend = ChunckedPrefillForPrefillNode(self.info_queue, self.mem_queue) elif is_nixl_prefill_node: if self.args.dp > 1: - self.backend = NIXLDPChunkedForPrefillNode(self.info_queue, self.mem_queue) + self.backend = NIXLDPChunkedForPrefillNode(self.info_queue, self.mem_queue, self.mem_queues) else: self.backend = NIXLChunckedPrefillForPrefillNode(self.info_queue, self.mem_queue) From 1323ce10c0f0cbc25b9537ebfa448fec5dde0d48 Mon Sep 17 00:00:00 2001 From: wanzihao <1060304770@qq.com> Date: Wed, 5 Nov 2025 04:42:08 +0000 Subject: [PATCH 08/34] little update --- .../mode_backend/dp_backend/impl.py | 16 +- .../mode_backend/dp_backend/p2p_fix.py | 155 ------------------ 2 files changed, 6 insertions(+), 165 deletions(-) delete mode 100644 lightllm/server/router/model_infer/mode_backend/dp_backend/p2p_fix.py diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py index bbb076e63..2722101bb 100644 --- a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py @@ -29,7 +29,7 @@ from lightllm.common.mem_manager import MemoryManager import torch.multiprocessing as mp -min_trans_token_num = os.getenv("MIN_TRANS_TOKEN_NUM", 128) +min_trans_token_num = os.getenv("MIN_TRANS_TOKEN_NUM", 512) class DPChunkedPrefillBackend(ModeBackend): @@ -82,13 +82,10 @@ def init_custom(self): if self.enable_dp_prompt_cache_fetch: torch.cuda.set_device(get_current_device_id()) - from .p2p_fix import reduce_tensor + from lightllm.server.router.model_infer.mode_backend.continues_batch.pd_mode.p2p_fix import reduce_tensor 
mp.reductions.reduce_tensor.__code__ = reduce_tensor.__code__ - # 必须先设置 CUDA 设备上下文,因为序列化和反序列化 CUDA tensor 时 - # reduce_tensor 和 p2p_fix_rebuild_cuda_tensor 都会调用 torch.cuda.current_device() - for _ in range(self.node_world_size - 1): self.mem_queue.put(self.model.mem_manager) @@ -100,7 +97,6 @@ def init_custom(self): self.mem_managers.append(self.model.mem_manager) return - # 一些可以复用的通用功能函数 def _init_reqs(self, reqs: List[Tuple]): my_reqs = reqs other_reqs = [] @@ -124,8 +120,8 @@ def _match_radix_cache(self, shm_req): input_token_ids = shm_req.shm_prompt_ids.arr[0 : shm_req.input_len] key = torch.tensor(input_token_ids, dtype=torch.int64, device="cpu") key = key[0 : len(key) - 1] # 最后一个不需要,因为需要一个额外的token,让其在prefill的时候输出下一个token的值 - share_node, kv_len, value_tensor = g_infer_context.radix_cache.match_prefix(key, update_refs=True) - return share_node, kv_len, value_tensor + _, kv_len, value_tensor = g_infer_context.radix_cache.match_prefix(key, update_refs=True) + return kv_len, value_tensor def _post_init_reqs(self, infer_reqs: List[InferReq], other_reqs: List[Tuple] = []): my_match = [] @@ -136,7 +132,7 @@ def _post_init_reqs(self, infer_reqs: List[InferReq], other_reqs: List[Tuple] = continue shm_req = r.shm_req - _, kv_len, value_tensor = self._match_radix_cache(shm_req) + kv_len, value_tensor = self._match_radix_cache(shm_req) # only the first rank is ok if self.rank_in_dp == 0: with g_infer_context.shm_req_manager.get_req_lock_by_index(shm_req.index_in_shm_mem): @@ -156,7 +152,7 @@ def _post_init_reqs(self, infer_reqs: List[InferReq], other_reqs: List[Tuple] = shm_req.link_prompt_ids_shm_array() shm_req.link_kv_indexes_shm_array() - _, kv_len, value_tensor = self._match_radix_cache(shm_req) + kv_len, value_tensor = self._match_radix_cache(shm_req) with g_infer_context.shm_req_manager.get_req_lock_by_index(shm_req.index_in_shm_mem): if kv_len > shm_req.dp_max_kv_len: shm_req.dp_max_kv_len = kv_len diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/p2p_fix.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/p2p_fix.py deleted file mode 100644 index 62609c4c9..000000000 --- a/lightllm/server/router/model_infer/mode_backend/dp_backend/p2p_fix.py +++ /dev/null @@ -1,155 +0,0 @@ -# mypy: allow-untyped-defs -import multiprocessing -import os -import threading -from multiprocessing.reduction import ForkingPickler -from multiprocessing.util import register_after_fork -from typing import Union - -import torch -import torch.utils.hooks -from torch._namedtensor_internals import check_serializing_named_tensor -from torch.multiprocessing.reductions import storage_from_cache, shared_cache, StorageWeakRef -from torch.multiprocessing.reductions import reduce_nested_tensor, reduce_sparse_tensor, rebuild_tensor - - -def p2p_fix_rebuild_cuda_tensor( - tensor_cls, - tensor_size, - tensor_stride, - tensor_offset, - storage_cls, - dtype, - storage_device, - storage_handle, - storage_size_bytes, - storage_offset_bytes, - requires_grad, - ref_counter_handle, - ref_counter_offset, - event_handle, - event_sync_required, -): - # 因为接收进程在将 tensor 对应的 handle重新转化为指针的时候 - # 在其c++源码中会将当前显卡切换到storage_device再做操作,这样 - # 得到的指针可能不是接收进程当前上下文设备可以访问的,所以在这里 - # hack 修改了使用的 storage_device,这样后续tritonkernel同时 - # 访问几张显卡上的数据,进行p2p操作就不会出问题了。 - storage_device = torch.cuda.current_device() - # If storage_handle is None, storage points to nullptr. 
- if storage_handle is None or storage_size_bytes == 0: - storage = storage_cls(0, dtype=dtype, device=storage_device, _internal=True) - else: - storage = storage_from_cache(storage_cls, (storage_handle, storage_offset_bytes)) - if storage is None: - torch.cuda._lazy_init() - storage = storage_cls._new_shared_cuda( - storage_device, - storage_handle, - storage_size_bytes, - storage_offset_bytes, - ref_counter_handle, - ref_counter_offset, - event_handle, - event_sync_required, - ) - shared_cache[(storage_handle, storage_offset_bytes)] = StorageWeakRef(storage) - else: - # We already ref counting this Storage, but producer needs new ref-counters to be released. - storage_cls._release_ipc_counter(ref_counter_handle, ref_counter_offset, device=storage_device) - - _storage = storage if isinstance(storage, torch.UntypedStorage) else storage._untyped_storage - - t = torch._utils._rebuild_tensor( - torch.storage.TypedStorage(wrap_storage=_storage, dtype=dtype, _internal=True), - tensor_offset, - tensor_size, - tensor_stride, - ) - - if tensor_cls == torch.nn.parameter.Parameter: - # It is crucial for integer tensors to receive - # the requires_grad=False as an argument in the constructor - t = torch.nn.parameter.Parameter(t, requires_grad=requires_grad) - else: - t.requires_grad = requires_grad - - return t - - -def reduce_tensor(tensor): - if tensor.requires_grad and not tensor.is_leaf: - raise RuntimeError( - "Cowardly refusing to serialize non-leaf tensor which requires_grad, " - "since autograd does not support crossing process boundaries. " - "If you just want to transfer the data, call detach() on the tensor " - "before serializing (e.g., putting it on the queue)." - ) - - check_serializing_named_tensor(tensor) - torch.utils.hooks.warn_if_has_hooks(tensor) - - from torch.nested._internal.nested_tensor import NestedTensor - - if tensor.is_nested and not isinstance(tensor, NestedTensor): - return reduce_nested_tensor(tensor) - - if tensor.layout in { - torch.sparse_coo, - torch.sparse_csr, - torch.sparse_bsr, - torch.sparse_csc, - torch.sparse_bsc, - }: - return reduce_sparse_tensor(tensor) - - storage = tensor._typed_storage() - - if storage._untyped_storage.device.type == "cuda": - ( - device, - handle, - storage_size_bytes, - storage_offset_bytes, - ref_counter_handle, - ref_counter_offset, - event_handle, - event_sync_required, - ) = storage._share_cuda_() - tensor_offset = tensor.storage_offset() - shared_cache[handle] = StorageWeakRef(storage) - # _backward_hooks purposely omitted here, see - # Note [Don't serialize hooks] - from lightllm.server.router.model_infer.mode_backend.continues_batch.pd_mode.p2p_fix import ( - p2p_fix_rebuild_cuda_tensor, - ) - - return ( - p2p_fix_rebuild_cuda_tensor, - ( - type(tensor), - tensor.size(), - tensor.stride(), - tensor_offset, # tensor offset in its storage - type(storage), - tensor.dtype, - device, - handle, # identifier which CUDA allocation is the storage in. 
- storage_size_bytes, # size(in bytes) of the storage - storage_offset_bytes, # offset(in bytes) of the storage in the CUDA allocation - tensor.requires_grad, - ref_counter_handle, - ref_counter_offset, - event_handle, - event_sync_required, - ), - ) - - # _backward_hooks purposely omitted here, see Note [Don't serialize hooks] - metadata = ( - tensor.storage_offset(), - tensor.size(), - tensor.stride(), - tensor.requires_grad, - ) - return (rebuild_tensor, (type(tensor), storage, metadata)) From b72b7ac47863e96b311a00e74d96f07ccb278f0c Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 9 Dec 2025 05:43:13 +0000 Subject: [PATCH 09/34] fix --- lightllm/server/api_cli.py | 4 ++-- lightllm/server/api_start.py | 6 +++--- lightllm/server/core/objs/start_args_type.py | 2 +- .../router/model_infer/mode_backend/dp_backend/impl.py | 7 ++----- 4 files changed, 8 insertions(+), 11 deletions(-) diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index d33b459c3..bf0e89887 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -567,9 +567,9 @@ def make_argument_parser() -> argparse.ArgumentParser: help="""Directory used to persist disk cache data. Defaults to a temp directory when not set.""", ) parser.add_argument( - "--disable_dp_prompt_cache_fetch", + "--enable_dp_prompt_cache_fetch", action="store_true", default=False, - help="""Disable prefix prompt cache fetch for data parallel inference, enabled by default.""", + help="""Enable prefix prompt cache fetch for data parallel inference, disabled by default.""", ) return parser diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 92cc98f5a..332a45283 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -270,11 +270,11 @@ def normal_or_p_d_start(args): send_and_receive_node_ip(args) # 多机用于收发node ip # PD 分离模式下必须禁用 DP prompt cache fetch,且 dp 必须 > 1 - if not args.disable_dp_prompt_cache_fetch and args.dp <= 1: - args.disable_dp_prompt_cache_fetch = True + if args.enable_dp_prompt_cache_fetch and args.dp <= 1: + args.enable_dp_prompt_cache_fetch = False logger.warning( """dp <= 1 does not support dp_prompt_cache_fetch; - overriding disable_dp_prompt_cache_fetch to True""" + overriding enable_dp_prompt_cache_fetch to False""" ) set_env_start_args(args) diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 0e186c256..71cafd6c4 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -114,7 +114,7 @@ class StartArgs: enable_disk_cache: bool = field(default=False) disk_cache_storage_size: float = field(default=10) disk_cache_dir: Optional[str] = field(default=None) - disable_dp_prompt_cache_fetch: bool = field(default=False) + enable_dp_prompt_cache_fetch: bool = field(default=False) # zmp ports router_port: int = field(default=None) detokenization_port: int = field(default=None) diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py index 2722101bb..5bcdab94a 100644 --- a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py @@ -38,7 +38,7 @@ def __init__(self, mem_queue: mp.Queue, mem_queues: List[mp.Queue] = None) -> No # 用于控制每一步是执行prefill 和 decode 还是跳过 self.control_state_machine = DPControlState(backend=self) - self.disable_dp_prompt_cache_fetch = 
get_env_start_args().disable_dp_prompt_cache_fetch + self.enable_dp_prompt_cache_fetch = get_env_start_args().enable_dp_prompt_cache_fetch self.min_trans_token_num = min_trans_token_num # 在 mtp 模式下切换绑定的prefill 和 decode 函数 @@ -76,9 +76,6 @@ def __init__(self, mem_queue: mp.Queue, mem_queues: List[mp.Queue] = None) -> No return def init_custom(self): - self.enable_dp_prompt_cache_fetch = ( - not self.disable_dp_prompt_cache_fetch and self.dp_size_in_node > 1 and self.mem_queues is not None - ) if self.enable_dp_prompt_cache_fetch: torch.cuda.set_device(get_current_device_id()) @@ -120,7 +117,7 @@ def _match_radix_cache(self, shm_req): input_token_ids = shm_req.shm_prompt_ids.arr[0 : shm_req.input_len] key = torch.tensor(input_token_ids, dtype=torch.int64, device="cpu") key = key[0 : len(key) - 1] # 最后一个不需要,因为需要一个额外的token,让其在prefill的时候输出下一个token的值 - _, kv_len, value_tensor = g_infer_context.radix_cache.match_prefix(key, update_refs=True) + _, kv_len, value_tensor = g_infer_context.radix_cache.match_prefix(key, update_refs=False) return kv_len, value_tensor def _post_init_reqs(self, infer_reqs: List[InferReq], other_reqs: List[Tuple] = []): From 518b8f1c3ea853ac18d0e0fddc2dc9a31656409e Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 9 Dec 2025 05:44:11 +0000 Subject: [PATCH 10/34] fix --- .../kv_cache_mem_manager/mem_manager.py | 20 +++++++++++++ lightllm/server/router/manager.py | 13 +++----- .../pd_mode/decode_node_impl/decode_impl.py | 3 +- .../decode_node_impl/decode_impl_for_dp.py | 4 +-- .../decode_node_impl/decode_infer_rpyc.py | 6 ++-- .../decode_kv_move_manager.py | 17 +++++------ .../decode_node_impl/decode_trans_obj.py | 3 +- .../decode_node_impl/decode_trans_process.py | 9 +++--- .../pd_mode/prefill_node_impl/prefill_impl.py | 3 +- .../prefill_node_impl/prefill_impl_for_dp.py | 4 +-- .../prefill_node_impl/prefill_infer_rpyc.py | 6 ++-- .../prefill_kv_move_manager.py | 17 +++++------ .../prefill_node_impl/prefill_trans_obj.py | 3 +- .../prefill_trans_process.py | 11 ++++--- .../mode_backend/dp_backend/impl.py | 21 ++++++------- .../pd_nixl/decode_node_impl/decode_impl.py | 3 +- .../decode_node_impl/decode_impl_for_dp.py | 4 +-- .../pd_nixl/prefill_node_impl/prefill_impl.py | 3 +- .../prefill_node_impl/prefill_impl_for_dp.py | 4 +-- .../server/router/model_infer/model_rpc.py | 30 +++++++------------ 20 files changed, 89 insertions(+), 95 deletions(-) diff --git a/lightllm/common/kv_cache_mem_manager/mem_manager.py b/lightllm/common/kv_cache_mem_manager/mem_manager.py index 3ffac99a5..97b520f5e 100755 --- a/lightllm/common/kv_cache_mem_manager/mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/mem_manager.py @@ -15,8 +15,11 @@ from lightllm.utils.dist_utils import get_current_device_id from lightllm.utils.config_utils import get_num_key_value_heads from lightllm.common.kv_trans_kernel.nixl_kv_trans import page_io +from lightllm.utils.shm_utils import create_or_link_shm +from multiprocessing.reduction import ForkingPickler logger = init_logger(__name__) +LIGHTLLM_MEM_MANAGER_SHM_SIZE = int(os.getenv("LIGHTLLM_MEM_MANAGER_SHM_SIZE", 1024 * 1024)) class MemoryManager: @@ -431,6 +434,23 @@ def copy_kv_from_other_dp_ranks( rank_in_dp=rank_in_dp, ) + def create_shm(self): + obj_bytes = ForkingPickler.dumps(self) + shm = create_or_link_shm( + f"{get_unique_server_name()}_mem_manager_{get_current_rank_in_node()}", LIGHTLLM_MEM_MANAGER_SHM_SIZE + ) + logger.info(f"create shm {shm.name} size {shm.size} obj size {len(obj_bytes)}") + shm.buf[0:4] = len(obj_bytes).to_bytes(4, "little") 
+ shm.buf[4 : 4 + len(obj_bytes)] = obj_bytes + + @staticmethod + def from_shm(rank_in_node): + shm = create_or_link_shm( + f"{get_unique_server_name()}_mem_manager_{rank_in_node}", LIGHTLLM_MEM_MANAGER_SHM_SIZE + ) + bytes_len = int.from_bytes(shm.buf[0:4], "little") + return ForkingPickler.loads(shm.buf[4 : 4 + bytes_len]) + class ReadOnlyStaticsMemoryManager: """ diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index b1d2805c2..89c46d9ed 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -116,9 +116,6 @@ async def wait_to_model_ready(self): self.model_rpc_servers = [] # 用于 kv move 管理进程 和 推理进程进行task信息的交互。 self.info_queue: mp.Queue = mp.Queue() - self.mem_queues: List[torch.multiprocessing.Queue] = [ - torch.multiprocessing.Queue() for _ in range(self.node_world_size) - ] self.rpc_event = multiprocessing.Event() self.rpc_finished_event = multiprocessing.Event() @@ -137,9 +134,7 @@ async def wait_to_model_ready(self): rpc_event=self.rpc_event, rpc_finished_event=self.rpc_finished_event, info_queue=self.info_queue, - mem_queue=self.mem_queues[(rank_id % node_world_size)], router_lock=self.router_lock, - mem_queues=self.mem_queues, ) ) tasks.append(task) @@ -206,14 +201,14 @@ async def wait_to_model_ready(self): start_prefill_kv_move_manager_process, ) - start_prefill_kv_move_manager_process(self.args, self.info_queue, self.mem_queues) + start_prefill_kv_move_manager_process(self.args, self.info_queue) if self.args.run_mode == "nixl_prefill": from lightllm.server.router.model_infer.mode_backend.pd_nixl.prefill_node_impl import ( start_prefill_kv_move_manager_process, ) - start_prefill_kv_move_manager_process(self.args, self.info_queue, self.mem_queues) + start_prefill_kv_move_manager_process(self.args, self.info_queue) if self.args.run_mode == "decode": # 启动 decode kv move 管理进程 @@ -221,14 +216,14 @@ async def wait_to_model_ready(self): start_decode_kv_move_manager_process, ) - start_decode_kv_move_manager_process(self.args, self.info_queue, self.mem_queues) + start_decode_kv_move_manager_process(self.args, self.info_queue) if self.args.run_mode == "nixl_decode": from lightllm.server.router.model_infer.mode_backend.pd_nixl.decode_node_impl import ( start_decode_kv_move_manager_process, ) - start_decode_kv_move_manager_process(self.args, self.info_queue, self.mem_queues) + start_decode_kv_move_manager_process(self.args, self.info_queue) return diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl.py index cb8d25358..f867d512d 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl.py @@ -19,10 +19,9 @@ class DecodeNode(ChunkedPrefillBackend): - def __init__(self, info_queue: mp.Queue, mem_queue: mp.Queue) -> None: + def __init__(self, info_queue: mp.Queue) -> None: super().__init__() self.info_queue: mp.Queue = info_queue - self.mem_queue: mp.Queue = mem_queue self.classed_req_strict_prefill = False def init_custom(self): diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl_for_dp.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl_for_dp.py index 2d637c817..8dc9ad1a6 100644 --- 
a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl_for_dp.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl_for_dp.py @@ -9,8 +9,8 @@ class DPForDecodeNode(DPChunkedPrefillBackend): - def __init__(self, info_queue: mp.Queue, mem_queue: mp.Queue) -> None: - super().__init__(mem_queue=mem_queue) + def __init__(self, info_queue: mp.Queue) -> None: + super().__init__() self.info_queue: mp.Queue = info_queue self.classed_req_strict_prefill = False return diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_infer_rpyc.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_infer_rpyc.py index 1e5cccb1f..202b624e5 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_infer_rpyc.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_infer_rpyc.py @@ -166,9 +166,9 @@ def exposed_fail_to_realese_forzen_tokens(self, group_req_ids: List[int]): release_acquired_lock() return - def exposed_put_mem_manager_to_mem_queue(self): - self.backend.mem_queue.put(self.backend.model.mem_manager) - logger.info("put mem manager to info_queues ok") + def exposed_put_mem_manager_to_shm(self): + self.backend.model.mem_manager.create_shm() + logger.info("put mem manager to shm ok") return def exposed_unfrozen_time_out_reqs_tokens(self): diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py index 8b818583c..88904018b 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py @@ -33,7 +33,7 @@ class DecodeKVMoveManager(rpyc.Service): - def __init__(self, args, info_queue: mp.Queue, mem_queues: List[mp.Queue]): + def __init__(self, args, info_queue: mp.Queue): super().__init__() self.args = args # args.dp // args.nnodes 在跨机tp的场景下,可能为0 @@ -44,7 +44,6 @@ def __init__(self, args, info_queue: mp.Queue, mem_queues: List[mp.Queue]): assert self.dp_world_size <= self.node_world_size self.info_queue = info_queue - self.mem_queues = mem_queues self.infer_rpyc_lock = threading.Lock() self.infer_rpyc_objs: List[PDDecodeInferRpcServer] = [] @@ -87,7 +86,7 @@ def __init__(self, args, info_queue: mp.Queue, mem_queues: List[mp.Queue]): # _put_kv_received_to_radix_cache # _fail_to_realese_forzen_tokens # _unfrozen_time_out_reqs_tokens - # _put_mem_manager_to_mem_queue + # _put_mem_manager_to_shm # 上述接口都是 kv move manager 与推理进程进行交互的接口,主要用于申请锁定kv资源或者释放 # kv资源的接口 # ================================================================================== @@ -155,10 +154,10 @@ def _unfrozen_time_out_reqs_tokens(self) -> None: asyncio.run(self.wait_all_future_finish(futures)) return - def _put_mem_manager_to_mem_queue(self) -> None: + def _put_mem_manager_to_shm(self) -> None: with self.infer_rpyc_lock: for obj in self.infer_rpyc_objs: - obj.put_mem_manager_to_mem_queue() + obj.put_mem_manager_to_shm() return # ================================================================================== @@ -362,14 +361,14 @@ def remove_trans_obj(self, connect_id): return -def 
_init_env(args, info_queue: mp.Queue, mem_queues: List[mp.Queue], event: mp.Event): +def _init_env(args, info_queue: mp.Queue, event: mp.Event): import lightllm.utils.rpyc_fix_utils as _ # 注册graceful 退出的处理 graceful_registry(inspect.currentframe().f_code.co_name) setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::decode_kv_move_manager") - manager = DecodeKVMoveManager(args, info_queue, mem_queues) + manager = DecodeKVMoveManager(args, info_queue) t = ThreadedServer(manager, port=args.pd_decode_rpyc_port, protocol_config={"allow_pickle": True}) threading.Thread(target=lambda: t.start(), daemon=True).start() @@ -381,9 +380,9 @@ def _init_env(args, info_queue: mp.Queue, mem_queues: List[mp.Queue], event: mp. return -def start_decode_kv_move_manager_process(args, info_queue: mp.Queue, mem_queues: List[mp.Queue]): +def start_decode_kv_move_manager_process(args, info_queue: mp.Queue): event = mp.Event() - proc = mp.Process(target=_init_env, args=(args, info_queue, mem_queues, event)) + proc = mp.Process(target=_init_env, args=(args, info_queue, event)) proc.start() event.wait() assert proc.is_alive() diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_obj.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_obj.py index cb6ed1939..8281cbe27 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_obj.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_obj.py @@ -279,10 +279,9 @@ def init_all(self, device_id: int, manager: "DecodeKVMoveManager"): device_id, self.task_in_queue, self.task_out_queue, - manager.mem_queues, ) assert self.task_out_queue.get(timeout=30) == "proc_start" - manager._put_mem_manager_to_mem_queue() + manager._put_mem_manager_to_shm() assert self.task_out_queue.get(timeout=60) == "get_mem_managers_ok" return True diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py index 9d72bac6e..7965a9449 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py @@ -91,7 +91,7 @@ def async_connect(): logger.warning(f"error while connect to prefill node: {e}") -def _init_env(args, device_id: int, task_in_queue: mp.Queue, task_out_queue: mp.Queue, mem_queues: List[mp.Queue]): +def _init_env(args, device_id: int, task_in_queue: mp.Queue, task_out_queue: mp.Queue): import os # os.environ["NCCL_DEBUG"] = "INFO" @@ -111,7 +111,9 @@ def _init_env(args, device_id: int, task_in_queue: mp.Queue, task_out_queue: mp. 
graceful_registry(inspect.currentframe().f_code.co_name) task_out_queue.put("proc_start") - mem_managers: List[MemoryManager] = [mem_queue.get(timeout=60) for mem_queue in mem_queues] + # 从共享内存读取所有rank的mem_manager + node_world_size = args.tp // args.nnodes + mem_managers: List[MemoryManager] = [MemoryManager.from_shm(rank) for rank in range(node_world_size)] task_out_queue.put("get_mem_managers_ok") connect_id_to_comm: Dict[str, PyNcclCommunicator] = {} @@ -143,9 +145,8 @@ def start_decode_trans_process( device_id: int, task_in_queue: mp.Queue, task_out_queue: mp.Queue, - mem_queues: List[mp.Queue], ): - proc = mp.Process(target=_init_env, args=(args, device_id, task_in_queue, task_out_queue, mem_queues)) + proc = mp.Process(target=_init_env, args=(args, device_id, task_in_queue, task_out_queue)) proc.start() assert proc.is_alive() logger.info(f"decode trans kv process for device: {device_id} start!") diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl.py index a2ff08bd2..441fc5cd8 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl.py @@ -20,11 +20,10 @@ class ChunckedPrefillForPrefillNode(ChunkedPrefillBackend): - def __init__(self, info_queue: mp.Queue, mem_queue: mp.Queue) -> None: + def __init__(self, info_queue: mp.Queue) -> None: super().__init__() self.support_overlap = False self.info_queue: mp.Queue = info_queue - self.mem_queue: mp.Queue = mem_queue self.classed_req_no_decode = True def init_custom(self): diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl_for_dp.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl_for_dp.py index 38d0f014c..4e2c35153 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl_for_dp.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl_for_dp.py @@ -9,8 +9,8 @@ class DPChunkedForPrefillNode(DPChunkedPrefillBackend): - def __init__(self, info_queue: mp.Queue, mem_queue: mp.Queue, mem_queues: List[mp.Queue]) -> None: - super().__init__(mem_queue=mem_queue, mem_queues=mem_queues) + def __init__(self, info_queue: mp.Queue) -> None: + super().__init__() self.support_overlap = False self.info_queue: mp.Queue = info_queue self.classed_req_no_decode = True diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_infer_rpyc.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_infer_rpyc.py index 8ed16511e..12c6e9471 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_infer_rpyc.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_infer_rpyc.py @@ -46,7 +46,7 @@ def exposed_remove_req_refs_from_prompt_cache(self, group_req_ids: List[int]): release_acquired_lock() return - def exposed_put_mem_manager_to_mem_queue(self): - self.backend.mem_queue.put(self.backend.model.mem_manager) - logger.info("put mem manager to mem_queue ok") + def exposed_put_mem_manager_to_shm(self): + 
self.backend.model.mem_manager.create_shm() + logger.info("put mem manager to shm ok") return diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py index 78f262464..b0d1851c4 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py @@ -30,7 +30,7 @@ class PrefillKVMoveManager: - def __init__(self, args, info_queue: mp.Queue, mem_queues: List[mp.Queue]): + def __init__(self, args, info_queue: mp.Queue): self.args = args # args.dp // args.nnodes 在跨机tp的场景下,可能为0 self.dp_size_in_node = max(1, args.dp // args.nnodes) @@ -40,7 +40,6 @@ def __init__(self, args, info_queue: mp.Queue, mem_queues: List[mp.Queue]): assert self.dp_world_size <= self.node_world_size self.info_queue = info_queue - self.mem_queues = mem_queues self.infer_rpyc_objs: List[PDPrefillInferRpcServer] = [] from .prefill_trans_obj import KVTransConnectObj @@ -144,7 +143,7 @@ def check_trans_process_loop(self): # ================================================================================== # 与推理进程交互接口, _remove_req_refs_from_prompt_cache 和 - # _put_mem_manager_to_mem_queue 都是通过 rpyc 与推理进程进行交互的接口 + # _put_mem_manager_to_shm 都是通过 rpyc 与推理进程进行交互的接口 # ================================================================================== def _remove_req_refs_from_prompt_cache(self, tasks: List[KVMoveTask]): @@ -164,10 +163,10 @@ def _remove_req_refs_from_prompt_cache(self, tasks: List[KVMoveTask]): asyncio.run(self.wait_all_future_finish(futures)) return - def _put_mem_manager_to_mem_queue(self): + def _put_mem_manager_to_shm(self): with self.infer_rpyc_lock: for obj in self.infer_rpyc_objs: - obj.put_mem_manager_to_mem_queue() + obj.put_mem_manager_to_shm() return async def wait_all_future_finish(self, futures: List[AsyncResult]): @@ -223,14 +222,14 @@ def __remove_dead_trans_obj(self): return -def _init_env(args, info_queue: mp.Queue, mem_queues: List[mp.Queue], event: mp.Event): +def _init_env(args, info_queue: mp.Queue, event: mp.Event): import lightllm.utils.rpyc_fix_utils as _ # 注册graceful 退出的处理 graceful_registry(inspect.currentframe().f_code.co_name) setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::prefill_kv_move_manager") - manager = PrefillKVMoveManager(args, info_queue, mem_queues) + manager = PrefillKVMoveManager(args, info_queue) kv_trans_process_check = threading.Thread(target=manager.check_trans_process_loop, daemon=True) kv_trans_process_check.start() event.set() @@ -239,9 +238,9 @@ def _init_env(args, info_queue: mp.Queue, mem_queues: List[mp.Queue], event: mp. 
return -def start_prefill_kv_move_manager_process(args, info_queue: mp.Queue, mem_queues: List[mp.Queue]): +def start_prefill_kv_move_manager_process(args, info_queue: mp.Queue): event = mp.Event() - proc = mp.Process(target=_init_env, args=(args, info_queue, mem_queues, event)) + proc = mp.Process(target=_init_env, args=(args, info_queue, event)) proc.start() event.wait() assert proc.is_alive() diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_obj.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_obj.py index f53761e09..68b86895d 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_obj.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_obj.py @@ -353,10 +353,9 @@ def init_all(self, device_id: int, manager: "PrefillKVMoveManager"): device_id, self.task_in_queue, self.task_out_queue, - manager.mem_queues, ) assert self.task_out_queue.get(timeout=30) == "proc_start" - manager._put_mem_manager_to_mem_queue() + manager._put_mem_manager_to_shm() assert self.task_out_queue.get(timeout=60) == "get_mem_managers_ok" return True diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py index 1972c47b4..84e3f5800 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py @@ -94,7 +94,6 @@ def _init_env( device_id, task_in_queue: mp.Queue, task_out_queue: mp.Queue, - mem_queues: List[mp.Queue], ): import os @@ -116,7 +115,10 @@ def _init_env( host_name=store_ip, port=store_port, is_master=True, use_libuv=True, timeout=timedelta(seconds=30) ) task_out_queue.put("proc_start") - mem_managers: List[MemoryManager] = [mem_queue.get(timeout=60) for mem_queue in mem_queues] + + # 从共享内存读取所有rank的mem_manager + node_world_size = args.tp // args.nnodes + mem_managers: List[MemoryManager] = [MemoryManager.from_shm(rank) for rank in range(node_world_size)] task_out_queue.put("get_mem_managers_ok") connect_id_to_comm: Dict[str, PyNcclCommunicator] = {} @@ -150,11 +152,8 @@ def start_prefill_trans_process( device_id, task_in_queue: mp.Queue, task_out_queue: mp.Queue, - mem_queues: List[mp.Queue], ): - proc = mp.Process( - target=_init_env, args=(args, store_ip, store_port, device_id, task_in_queue, task_out_queue, mem_queues) - ) + proc = mp.Process(target=_init_env, args=(args, store_ip, store_port, device_id, task_in_queue, task_out_queue)) proc.start() assert proc.is_alive() logger.info(f"prefill trans kv process for device: {device_id} started!") diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py index 5bcdab94a..60491d85f 100644 --- a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py @@ -33,7 +33,7 @@ class DPChunkedPrefillBackend(ModeBackend): - def __init__(self, mem_queue: mp.Queue, mem_queues: List[mp.Queue] = None) -> None: + def __init__(self) -> None: super().__init__() # 用于控制每一步是执行prefill 和 decode 还是跳过 
@@ -69,10 +69,6 @@ def __init__(self, mem_queue: mp.Queue, mem_queues: List[mp.Queue] = None) -> No self.decode = self.decode_normal self.classed_req_strict_prefill = False - if not self.disable_dp_prompt_cache_fetch and self.dp_size_in_node > 1 and mem_queues is not None: - self._init_dp_cache_fetch(mem_queues) - self.mem_queues = mem_queues - self.mem_queue = mem_queue return def init_custom(self): @@ -83,13 +79,14 @@ def init_custom(self): mp.reductions.reduce_tensor.__code__ = reduce_tensor.__code__ - for _ in range(self.node_world_size - 1): - self.mem_queue.put(self.model.mem_manager) + # 每个rank创建自己的共享内存并写入mem_manager + self.model.mem_manager.create_shm() + # 读取所有rank的mem_manager self.mem_managers = [] - for queue_index in range(len(self.mem_queues)): - if queue_index != self.rank_in_node: - self.mem_managers.append(self.mem_queues[queue_index].get(timeout=60)) + for rank_idx in range(self.node_world_size): + if rank_idx != self.rank_in_node: + self.mem_managers.append(MemoryManager.from_shm(rank_idx)) else: self.mem_managers.append(self.model.mem_manager) return @@ -105,7 +102,7 @@ def _init_reqs(self, reqs: List[Tuple]): g_infer_state_lock.acquire() infer_reqs = g_infer_context.add_reqs(my_reqs, init_prefix_cache=not self.enable_dp_prompt_cache_fetch) if self.enable_dp_prompt_cache_fetch: - self._post_init_reqs(infer_reqs, other_reqs=other_reqs) + self._fetch_dp_prompt_cache(infer_reqs, other_reqs=other_reqs) for r in infer_reqs: r._match_radix_cache() @@ -120,7 +117,7 @@ def _match_radix_cache(self, shm_req): _, kv_len, value_tensor = g_infer_context.radix_cache.match_prefix(key, update_refs=False) return kv_len, value_tensor - def _post_init_reqs(self, infer_reqs: List[InferReq], other_reqs: List[Tuple] = []): + def _fetch_dp_prompt_cache(self, infer_reqs: List[InferReq], other_reqs: List[Tuple] = []): my_match = [] other_match = [] # match all the reqs in this dp rank. 
diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_impl.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_impl.py index 1a16ae867..8529bf81b 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_impl.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_impl.py @@ -12,10 +12,9 @@ class NIXLDecodeNode(ChunkedPrefillBackend): - def __init__(self, info_queue: mp.Queue, mem_queue: mp.Queue) -> None: + def __init__(self, info_queue: mp.Queue) -> None: super().__init__() self.info_queue: mp.Queue = info_queue - self.mem_queue: mp.Queue = mem_queue self.classed_req_strict_prefill = False def init_custom(self): diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_impl_for_dp.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_impl_for_dp.py index 758ebf50b..8bf0dd7c5 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_impl_for_dp.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_impl_for_dp.py @@ -9,8 +9,8 @@ class NIXLDPForDecodeNode(DPChunkedPrefillBackend): - def __init__(self, info_queue: mp.Queue, mem_queue: mp.Queue) -> None: - super().__init__(mem_queue=mem_queue) + def __init__(self, info_queue: mp.Queue) -> None: + super().__init__() self.info_queue: mp.Queue = info_queue self.classed_req_strict_prefill = False return diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_impl.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_impl.py index 0d5a6de79..e9e61571a 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_impl.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_impl.py @@ -12,11 +12,10 @@ class NIXLChunckedPrefillForPrefillNode(ChunkedPrefillBackend): - def __init__(self, info_queue: mp.Queue, mem_queue: mp.Queue) -> None: + def __init__(self, info_queue: mp.Queue) -> None: super().__init__() self.support_overlap = False self.info_queue: mp.Queue = info_queue - self.mem_queue: mp.Queue = mem_queue self.classed_req_no_decode = True self.nixl_prefill_chuncked_handle_func = self._prefill_chuncked_handle_func diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_impl_for_dp.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_impl_for_dp.py index b9a7ec330..2c6c295bc 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_impl_for_dp.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_impl_for_dp.py @@ -9,8 +9,8 @@ class NIXLDPChunkedForPrefillNode(DPChunkedPrefillBackend): - def __init__(self, info_queue: mp.Queue, mem_queue: mp.Queue, mem_queues: List[mp.Queue]) -> None: - super().__init__(mem_queue=mem_queue, mem_queues=mem_queues) + def __init__(self, info_queue: mp.Queue) -> None: + super().__init__() self.support_overlap = False self.info_queue: mp.Queue = info_queue self.classed_req_no_decode = True diff --git a/lightllm/server/router/model_infer/model_rpc.py b/lightllm/server/router/model_infer/model_rpc.py index 134a8e9ad..55fe7a415 100644 --- a/lightllm/server/router/model_infer/model_rpc.py +++ b/lightllm/server/router/model_infer/model_rpc.py @@ -47,15 +47,11 @@ def 
__init__( rpc_event: multiprocessing.Event, rpc_finished_event: multiprocessing.Event, info_queue: mp.Queue, - mem_queue: mp.Queue, - mem_queues: List[mp.Queue] = None, ): super().__init__() self.args: StartArgs = args self.node_world_size = node_world_size self.info_queue = info_queue - self.mem_queue = mem_queue - self.mem_queues = mem_queues self.rpc_event = rpc_event self.rpc_finished_event = rpc_finished_event @@ -129,29 +125,29 @@ def init_model(self, kvargs): if is_prefill_node: if self.args.dp > 1: - self.backend = DPChunkedForPrefillNode(self.info_queue, self.mem_queue, self.mem_queues) + self.backend = DPChunkedForPrefillNode(self.info_queue) else: - self.backend = ChunckedPrefillForPrefillNode(self.info_queue, self.mem_queue) + self.backend = ChunckedPrefillForPrefillNode(self.info_queue) elif is_nixl_prefill_node: if self.args.dp > 1: - self.backend = NIXLDPChunkedForPrefillNode(self.info_queue, self.mem_queue, self.mem_queues) + self.backend = NIXLDPChunkedForPrefillNode(self.info_queue) else: - self.backend = NIXLChunckedPrefillForPrefillNode(self.info_queue, self.mem_queue) + self.backend = NIXLChunckedPrefillForPrefillNode(self.info_queue) elif is_decode_node: if self.args.dp > 1: - self.backend = DPForDecodeNode(self.info_queue, self.mem_queue) + self.backend = DPForDecodeNode(self.info_queue) else: - self.backend = DecodeNode(self.info_queue, self.mem_queue) + self.backend = DecodeNode(self.info_queue) elif is_nixl_decode_node: if self.args.dp > 1: - self.backend = NIXLDPForDecodeNode(self.info_queue, self.mem_queue) + self.backend = NIXLDPForDecodeNode(self.info_queue) else: - self.backend = NIXLDecodeNode(self.info_queue, self.mem_queue) + self.backend = NIXLDecodeNode(self.info_queue) elif self.args.dp > 1: - self.backend = DPChunkedPrefillBackend(mem_queue=self.mem_queue, mem_queues=self.mem_queues) + self.backend = DPChunkedPrefillBackend() elif use_reward_model: self.backend = RewardModelBackend() elif return_all_prompt_logprobs: @@ -220,12 +216,10 @@ def _init_env( rank_in_node, node_world_size, info_queue, - mem_queue, router_lock, rpc_event: mp.Event, rpc_finished_event: mp.Event, success_event: mp.Event, - mem_queues: List[mp.Queue] = None, ): import lightllm.utils.rpyc_fix_utils as _ @@ -240,7 +234,7 @@ def _init_env( g_router_lock.obj = router_lock model_rpc_server = ModelRpcServer( - args, rank, rank_in_node, node_world_size, rpc_event, rpc_finished_event, info_queue, mem_queue, mem_queues + args, rank, rank_in_node, node_world_size, rpc_event, rpc_finished_event, info_queue ) success_event.set() @@ -256,9 +250,7 @@ async def start_model_process( rpc_event, rpc_finished_event, info_queue: mp.Queue, - mem_queue: mp.Queue, router_lock: mp.Queue, - mem_queues: List[mp.Queue] = None, ): import lightllm.utils.rpyc_fix_utils as _ @@ -271,12 +263,10 @@ async def start_model_process( rank_in_node, node_world_size, info_queue, - mem_queue, router_lock, rpc_event, rpc_finished_event, success_event, - mem_queues, ), ) proc.start() From da2eb3d1879982aaa3eefb13bfb79c552a933220 Mon Sep 17 00:00:00 2001 From: wanzihao <1060304770@qq.com> Date: Thu, 6 Nov 2025 08:37:02 +0000 Subject: [PATCH 11/34] use node_nccl_group --- .../router/model_infer/mode_backend/dp_backend/impl.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py index 60491d85f..c3c9bba43 100644 --- 
a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py @@ -79,10 +79,10 @@ def init_custom(self): mp.reductions.reduce_tensor.__code__ = reduce_tensor.__code__ - # 每个rank创建自己的共享内存并写入mem_manager self.model.mem_manager.create_shm() - # 读取所有rank的mem_manager + dist.barrier(group=self.node_nccl_group) + self.mem_managers = [] for rank_idx in range(self.node_world_size): if rank_idx != self.rank_in_node: @@ -154,7 +154,7 @@ def _fetch_dp_prompt_cache(self, infer_reqs: List[InferReq], other_reqs: List[Tu other_match.append((shm_req, kv_len, value_tensor)) # wait all the ranks to finish the match - dist.barrier() + dist.barrier(group=self.node_nccl_group) # Copy the kv_indexes of this dp rank to other required req for match in other_match: @@ -165,7 +165,7 @@ def _fetch_dp_prompt_cache(self, infer_reqs: List[InferReq], other_reqs: List[Tu self.release_all_shm_reqs([match[0] for match in other_match]) # wait all the ranks to finish the copy - dist.barrier() + dist.barrier(group=self.node_nccl_group) # Perform a kv transfer, get all indexes and the corresponding dp_rank move_token_indexes = [] From 737218e948dcfd8551b0879f91815a6081c89750 Mon Sep 17 00:00:00 2001 From: wanzihao <1060304770@qq.com> Date: Fri, 7 Nov 2025 09:15:11 +0000 Subject: [PATCH 12/34] delete shm_kv_indexes add shared_kv_indexes to reduce shared memory usage --- .../kv_cache_mem_manager/mem_manager.py | 2 +- lightllm/server/core/objs/req.py | 24 +--- .../mode_backend/dp_backend/impl.py | 113 +++++++++++++----- 3 files changed, 85 insertions(+), 54 deletions(-) diff --git a/lightllm/common/kv_cache_mem_manager/mem_manager.py b/lightllm/common/kv_cache_mem_manager/mem_manager.py index 97b520f5e..31296a328 100755 --- a/lightllm/common/kv_cache_mem_manager/mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/mem_manager.py @@ -19,7 +19,7 @@ from multiprocessing.reduction import ForkingPickler logger = init_logger(__name__) -LIGHTLLM_MEM_MANAGER_SHM_SIZE = int(os.getenv("LIGHTLLM_MEM_MANAGER_SHM_SIZE", 1024 * 1024)) +LIGHTLLM_MEM_MANAGER_SHM_SIZE = int(os.getenv("LIGHTLLM_MEM_MANAGER_SHM_SIZE", 16 * 1024)) class MemoryManager: diff --git a/lightllm/server/core/objs/req.py b/lightllm/server/core/objs/req.py index 61af36e5c..6ee68c52a 100644 --- a/lightllm/server/core/objs/req.py +++ b/lightllm/server/core/objs/req.py @@ -126,6 +126,8 @@ class Req(ctypes.Structure): ("dp_max_kv_len", ctypes.c_int), # 拥有最大kv cache长度的dp_rank ("dp_max_kv_rank", ctypes.c_int), + # 原DP的kv len + ("dp_origin_kv_len", ctypes.c_int), ] def get_str(self): @@ -177,7 +179,6 @@ def init( self.alloc_shm_numpy_len = self.input_len + self.sample_params.max_new_tokens + 1024 # + 1024 for safe self.create_logprobs_shm_array() self.create_prompt_ids_shm_array() - self.create_kv_indexes_shm_array() self.chunked_prefill_size = chunked_prefill_size self.shm_prompt_ids.arr[0 : len(prompt_ids)] = prompt_ids self.mtp_accepted_token_num = 0 @@ -191,6 +192,7 @@ def init( # 初始化DP模式相关字段 self.dp_max_kv_len = 0 self.dp_max_kv_rank = -1 + self.dp_origin_kv_len = 0 if get_env_start_args().enable_cpu_cache: self._fill_input_token_hash() return @@ -235,32 +237,12 @@ def link_logprobs_shm_array(self): self.shm_logprobs.link_shm() return - def create_kv_indexes_shm_array(self): - service_uni_name = get_unique_server_name() - name = f"{service_uni_name}_shm_kv_indexes_{self.index_in_shm_mem}" - self.shm_kv_indexes = ShmArray(name, (self.alloc_shm_numpy_len,), dtype=np.int64) - 
self.shm_kv_indexes.create_shm() - return - - def link_kv_indexes_shm_array(self): - service_uni_name = get_unique_server_name() - name = f"{service_uni_name}_shm_kv_indexes_{self.index_in_shm_mem}" - self.shm_kv_indexes = ShmArray(name, (self.alloc_shm_numpy_len,), dtype=np.int64) - self.shm_kv_indexes.link_shm() - return - def get_prompt_ids(self): return self.shm_prompt_ids.arr[: self.input_len].tolist() def get_prompt_ids_numpy(self): return self.shm_prompt_ids.arr[: self.input_len] - def get_kv_indexes(self): - return self.shm_kv_indexes.arr[: self.input_len].tolist() - - def get_kv_indexes_numpy(self): - return self.shm_kv_indexes.arr[: self.input_len] - def to_router_rpc_obj(self): if hasattr(self, "multimodal_params"): return ( diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py index c3c9bba43..8e0661942 100644 --- a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py @@ -29,7 +29,8 @@ from lightllm.common.mem_manager import MemoryManager import torch.multiprocessing as mp -min_trans_token_num = os.getenv("MIN_TRANS_TOKEN_NUM", 512) +min_trans_token_num = os.getenv("LIGHTLLM_MIN_TRANS_TOKEN_NUM", 512) +dp_kv_transfer_req_num = os.getenv("LIGHTLLM_DP_KV_TRANSFER_REQ_NUM", 16) class DPChunkedPrefillBackend(ModeBackend): @@ -76,19 +77,39 @@ def init_custom(self): torch.cuda.set_device(get_current_device_id()) from lightllm.server.router.model_infer.mode_backend.continues_batch.pd_mode.p2p_fix import reduce_tensor + from lightllm.server.core.objs.shm_array import ShmArray + from lightllm.utils.envs_utils import get_unique_server_name mp.reductions.reduce_tensor.__code__ = reduce_tensor.__code__ + # Create shared memory for mem_manager self.model.mem_manager.create_shm() + # Create shared ShmArray for kv_indexes transfer + # Use a small buffer to save shared memory + self.dp_kv_transfer_req_num = dp_kv_transfer_req_num + max_len = get_env_start_args().max_req_total_len + 8 + self.shared_kv_indexes_name = f"{get_unique_server_name()}_shared_kv_indexes_global" + self.shared_kv_indexes = ShmArray( + self.shared_kv_indexes_name, (self.dp_kv_transfer_req_num, max_len), dtype=np.int64 + ) + # Only rank_in_node == 0 creates the shared memory + if self.rank_in_node == 0: + self.shared_kv_indexes.create_shm() + dist.barrier(group=self.node_nccl_group) + # Collect mem_managers from all ranks self.mem_managers = [] for rank_idx in range(self.node_world_size): if rank_idx != self.rank_in_node: self.mem_managers.append(MemoryManager.from_shm(rank_idx)) else: self.mem_managers.append(self.model.mem_manager) + + # Other ranks link to the shared memory + if self.rank_in_node != 0: + self.shared_kv_indexes.link_shm() return def _init_reqs(self, reqs: List[Tuple]): @@ -102,7 +123,7 @@ def _init_reqs(self, reqs: List[Tuple]): g_infer_state_lock.acquire() infer_reqs = g_infer_context.add_reqs(my_reqs, init_prefix_cache=not self.enable_dp_prompt_cache_fetch) if self.enable_dp_prompt_cache_fetch: - self._fetch_dp_prompt_cache(infer_reqs, other_reqs=other_reqs) + self._fetch_dp_prompt_cache(infer_reqs, other_reqs=other_reqs, origin_reqs=reqs) for r in infer_reqs: r._match_radix_cache() @@ -117,7 +138,10 @@ def _match_radix_cache(self, shm_req): _, kv_len, value_tensor = g_infer_context.radix_cache.match_prefix(key, update_refs=False) return kv_len, value_tensor - def _fetch_dp_prompt_cache(self, infer_reqs: List[InferReq], 
other_reqs: List[Tuple] = []): + def _fetch_dp_prompt_cache( + self, infer_reqs: List[InferReq], other_reqs: List[Tuple] = [], origin_reqs: List[Tuple] = [] + ): + # shm_index_2_index = {r[1]: i for i, r in enumerate(origin_reqs)} my_match = [] other_match = [] # match all the reqs in this dp rank. @@ -130,6 +154,7 @@ def _fetch_dp_prompt_cache(self, infer_reqs: List[InferReq], other_reqs: List[Tu # only the first rank is ok if self.rank_in_dp == 0: with g_infer_context.shm_req_manager.get_req_lock_by_index(shm_req.index_in_shm_mem): + shm_req.dp_origin_kv_len = kv_len if kv_len > shm_req.dp_max_kv_len: shm_req.dp_max_kv_len = kv_len shm_req.dp_max_kv_rank = self.dp_rank_in_node # 单机 @@ -156,29 +181,61 @@ def _fetch_dp_prompt_cache(self, infer_reqs: List[InferReq], other_reqs: List[Tu # wait all the ranks to finish the match dist.barrier(group=self.node_nccl_group) - # Copy the kv_indexes of this dp rank to other required req - for match in other_match: + # 创建 shm_index 到匹配结果的映射 + shm_index_to_match = {match[0].index_in_shm_mem: match for match in my_match} + shm_index_to_match.update({match[0].index_in_shm_mem: match for match in other_match}) + + my_trans_match = [] + other_trans_match = [] + transfer_count = 0 + for r in origin_reqs: + _, shm_index, _, suggested_dp_index = r + shm_req, kv_len, value_tensor = shm_index_to_match[shm_index] + match = (shm_req, kv_len, value_tensor, suggested_dp_index) + + if suggested_dp_index != shm_req.dp_max_kv_rank: + if suggested_dp_index == self.dp_rank_in_node: + my_trans_match.append((match, transfer_count)) + else: + other_trans_match.append((match, transfer_count)) + transfer_count += 1 + + if transfer_count == self.dp_kv_transfer_req_num: + self._transfer_dp_kv_cache(my_trans_match, other_trans_match) + my_trans_match = [] + other_trans_match = [] + transfer_count = 0 + + if transfer_count > 0: + self._transfer_dp_kv_cache(my_trans_match, other_trans_match) + + def _transfer_dp_kv_cache(self, my_match: List[Tuple], other_match: List[Tuple]): + other_shm_reqs = [] + for match, index in other_match: shm_req, kv_len, value_tensor = match + trans_len = kv_len - shm_req.dp_origin_kv_len if shm_req.dp_max_kv_rank == self.dp_rank_in_node: - shm_req.shm_kv_indexes.arr[0:kv_len] = value_tensor - # release other dp_rank's shm_req - self.release_all_shm_reqs([match[0] for match in other_match]) + self.shared_kv_indexes.arr[index, 0:trans_len] = value_tensor[shm_req.dp_origin_kv_len : kv_len] + other_shm_reqs.append(shm_req) - # wait all the ranks to finish the copy + self.release_all_shm_reqs(other_shm_reqs) dist.barrier(group=self.node_nccl_group) - # Perform a kv transfer, get all indexes and the corresponding dp_rank + if not my_match: + return + move_token_indexes = [] token_dp_indexes = [] + trans_info = [] alloc_size = 0 - for r in my_match: - shm_req, kv_len, value_tensor = r + for match, index in my_match: + shm_req, kv_len, value_tensor = match trans_len = shm_req.dp_max_kv_len - kv_len if trans_len > 0 and shm_req.dp_max_kv_rank != self.dp_rank_in_node: - # Only copy kv_indexes that are not in this dp rank. 
- move_token_indexes.extend(shm_req.shm_kv_indexes.arr[kv_len : shm_req.dp_max_kv_len]) - token_dp_indexes.extend([shm_req.dp_max_kv_rank for _ in range(trans_len)]) - alloc_size += trans_len + move_token_indexes.extend(self.shared_kv_indexes.arr[index, 0:trans_len]) + token_dp_indexes.extend([shm_req.dp_max_kv_rank] * trans_len) + trans_info.append((shm_req, kv_len, value_tensor, trans_len)) + alloc_size += trans_len if alloc_size < self.min_trans_token_num: return @@ -189,7 +246,6 @@ def _fetch_dp_prompt_cache(self, infer_reqs: List[InferReq], other_reqs: List[Tu move_token_indexes = torch.tensor(move_token_indexes, dtype=torch.int64, device="cuda") token_dp_indexes = torch.tensor(token_dp_indexes, dtype=torch.int32, device="cuda") - # transfer kv self.model.mem_manager.copy_kv_from_other_dp_ranks( mem_managers=self.mem_managers, move_token_indexes=move_token_indexes, @@ -200,22 +256,15 @@ def _fetch_dp_prompt_cache(self, infer_reqs: List[InferReq], other_reqs: List[Tu ) self.logger.info(f"dp_i {self.dp_rank_in_node} transfer kv tokens num: {alloc_size}") - # 更新radix cache start_index = 0 - for r in my_match: - shm_req, kv_len, value_tensor = r - trans_len = shm_req.dp_max_kv_len - kv_len - if trans_len > 0 and shm_req.dp_max_kv_rank != self.dp_rank_in_node: - new_value_tensor = mem_indexes[start_index : start_index + trans_len].cpu() - start_index += trans_len - key = torch.tensor( - shm_req.shm_prompt_ids.arr[0 : shm_req.dp_max_kv_len], dtype=torch.int64, device="cpu" - ) - if value_tensor is not None: - value_tensor = torch.cat((value_tensor, new_value_tensor), dim=0) - else: - value_tensor = new_value_tensor - g_infer_context.radix_cache.insert(key, value_tensor) + for shm_req, kv_len, value_tensor, trans_len in trans_info: + new_value_tensor = mem_indexes[start_index : start_index + trans_len].cpu() + start_index += trans_len + key = torch.tensor(shm_req.shm_prompt_ids.arr[0 : shm_req.dp_max_kv_len], dtype=torch.int64, device="cpu") + value_tensor = ( + torch.cat((value_tensor, new_value_tensor), dim=0) if value_tensor is not None else new_value_tensor + ) + g_infer_context.radix_cache.insert(key, value_tensor) def release_all_shm_reqs(self, shm_reqs): for shm_req in shm_reqs: From b0743a7677b5824ad4e84cda3b7ffd69d9414c67 Mon Sep 17 00:00:00 2001 From: wanzihao <1060304770@qq.com> Date: Mon, 10 Nov 2025 03:40:00 +0000 Subject: [PATCH 13/34] layer into triton op --- .../kv_cache_mem_manager/mem_manager.py | 40 ++++++++-------- .../common/kv_trans_kernel/kv_trans_v2.py | 29 +++++++---- .../mode_backend/dp_backend/impl.py | 9 ++-- lightllm/utils/log_utils.py | 2 +- .../kv_trans_kernel/test_kv_trans_v2.py | 48 ++++++++++++++++++- 5 files changed, 93 insertions(+), 35 deletions(-) diff --git a/lightllm/common/kv_cache_mem_manager/mem_manager.py b/lightllm/common/kv_cache_mem_manager/mem_manager.py index 31296a328..6bb75cc4c 100755 --- a/lightllm/common/kv_cache_mem_manager/mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/mem_manager.py @@ -414,25 +414,23 @@ def copy_kv_from_other_dp_ranks( dp_size_in_node: int, rank_in_dp: int, ): - if not hasattr(self, "mem_ptrs_dict"): - self.mem_ptrs_dict = {} - for layer_index in range(self.layer_num): - mems_ptr = [] - for i in range(0, len(mem_managers)): - mems_ptr.append(mem_managers[i].kv_buffer[layer_index, :, :, :].data_ptr()) - mems_ptr = torch.tensor(mems_ptr, dtype=torch.uint64, device="cuda") - self.mem_ptrs_dict[layer_index] = mems_ptr - - for layer_index in range(self.layer_num): - kv_trans_for_dp( - 
input_mems=self.mem_ptrs_dict[layer_index], - input_idx=move_token_indexes, - input_dp_idx=token_dp_indexes, - output=self.kv_buffer[layer_index], - output_idx=mem_indexes, - dp_size_in_node=dp_size_in_node, - rank_in_dp=rank_in_dp, - ) + if not hasattr(self, "mem_ptrs_tensor"): + # 构建一个2D tensor,shape为(layer_num, mem_num) + mems_ptr_list = [] + for i in range(0, len(mem_managers)): + mems_ptr_list.append(mem_managers[i].kv_buffer.data_ptr()) + self.mem_ptrs_tensor = torch.tensor(mems_ptr_list, dtype=torch.uint64, device="cuda") + + # 一次性传输所有层 + kv_trans_for_dp( + input_mems=self.mem_ptrs_tensor, + input_idx=move_token_indexes, + input_dp_idx=token_dp_indexes, + output=self.kv_buffer, + output_idx=mem_indexes, + dp_size_in_node=dp_size_in_node, + rank_in_dp=rank_in_dp, + ) def create_shm(self): obj_bytes = ForkingPickler.dumps(self) @@ -449,7 +447,9 @@ def from_shm(rank_in_node): f"{get_unique_server_name()}_mem_manager_{rank_in_node}", LIGHTLLM_MEM_MANAGER_SHM_SIZE ) bytes_len = int.from_bytes(shm.buf[0:4], "little") - return ForkingPickler.loads(shm.buf[4 : 4 + bytes_len]) + obj_bytes = shm.buf[4 : 4 + bytes_len].tobytes() + shm.close() + return ForkingPickler.loads(obj_bytes) class ReadOnlyStaticsMemoryManager: diff --git a/lightllm/common/kv_trans_kernel/kv_trans_v2.py b/lightllm/common/kv_trans_kernel/kv_trans_v2.py index 912587e1e..5ee058225 100644 --- a/lightllm/common/kv_trans_kernel/kv_trans_v2.py +++ b/lightllm/common/kv_trans_kernel/kv_trans_v2.py @@ -199,13 +199,16 @@ def _kv_trans_for_dp_kernel( input_stride_0, input_stride_1, input_stride_2, + input_stride_3, input_token_idx_ptr, input_token_dp_index_ptr, output_ptr, output_stride_0, output_stride_1, output_stride_2, + output_stride_3, output_token_idx_ptr, + layer_num: tl.constexpr, token_num: int, head_num: int, head_dim: int, @@ -229,11 +232,20 @@ def _kv_trans_for_dp_kernel( mem_index = RANK_IN_DP + dp_index * CARD_NUM_PER_D input_token_idx = tl.load(input_token_idx_ptr + tid) output_token_idx = tl.load(output_token_idx_ptr + tid) - for block_idx in tl.range(0, tl.cdiv(head_num_dim, BLOCK_SIZE), 1, num_stages=NUM_STAGES): - cur_offs = block_idx * BLOCK_SIZE + offs - input_ptr = tl.load(input_mems_ptr + mem_index).to(tl.pointer_type(output_ptr.dtype.element_ty)) - in_datas = tl.load(input_ptr + input_stride_0 * input_token_idx + cur_offs, mask=cur_offs < head_num_dim) - tl.store(output_ptr + output_stride_0 * output_token_idx + cur_offs, in_datas, mask=cur_offs < head_num_dim) + + input_ptr = tl.load(input_mems_ptr + mem_index).to(tl.pointer_type(output_ptr.dtype.element_ty)) + for layer_idx in tl.range(0, layer_num, 1): + for block_idx in tl.range(0, tl.cdiv(head_num_dim, BLOCK_SIZE), 1, num_stages=NUM_STAGES): + cur_offs = block_idx * BLOCK_SIZE + offs + in_datas = tl.load( + input_ptr + input_stride_0 * layer_idx + input_stride_1 * input_token_idx + cur_offs, + mask=cur_offs < head_num_dim, + ) + tl.store( + output_ptr + output_stride_0 * layer_idx + output_stride_1 * output_token_idx + cur_offs, + in_datas, + mask=cur_offs < head_num_dim, + ) tid += grid_count @@ -250,19 +262,19 @@ def kv_trans_for_dp( rank_in_dp: int, ): """ - input_mems 是一个 torch.uint64 的tensor, 其内部存储了当前使用的对应的mem_manager对象中kv cache的首指针。 + input_mems 是一个 torch.uint64 的tensor, shape为(layer_num, mem_num),其内部存储了当前使用的对应的mem_manager对象中kv cache的首指针。 """ assert input_mems.is_contiguous() assert output.is_contiguous() assert len(input_mems.shape) == 1 - assert len(output.shape) == 3 + assert len(output.shape) == 4 assert len(input_idx) == len(output_idx) 
assert len(output_idx) == len(input_dp_idx) assert len(input_mems) % dp_size_in_node == 0 card_num_per_d = len(input_mems) // dp_size_in_node - _, head_num, head_dim = output.shape + layer_num, _, head_num, head_dim = output.shape token_num = len(output_idx) # 用较少的资源来做数据传输,防止占用过多的 sm 计算单元 grid_count = 20 @@ -278,6 +290,7 @@ def kv_trans_for_dp( output, *output.stride(), output_idx, + layer_num=layer_num, token_num=token_num, head_num=head_num, head_dim=head_dim, diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py index 8e0661942..94e0616d7 100644 --- a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py @@ -29,8 +29,8 @@ from lightllm.common.mem_manager import MemoryManager import torch.multiprocessing as mp -min_trans_token_num = os.getenv("LIGHTLLM_MIN_TRANS_TOKEN_NUM", 512) -dp_kv_transfer_req_num = os.getenv("LIGHTLLM_DP_KV_TRANSFER_REQ_NUM", 16) +min_trans_token_num = int(os.getenv("LIGHTLLM_MIN_TRANS_TOKEN_NUM", "512")) +dp_kv_transfer_req_num = int(os.getenv("LIGHTLLM_DP_KV_TRANSFER_REQ_NUM", "16")) class DPChunkedPrefillBackend(ModeBackend): @@ -169,7 +169,6 @@ def _fetch_dp_prompt_cache( if sampling_param.disable_prompt_cache: continue shm_req.link_prompt_ids_shm_array() - shm_req.link_kv_indexes_shm_array() kv_len, value_tensor = self._match_radix_cache(shm_req) with g_infer_context.shm_req_manager.get_req_lock_by_index(shm_req.index_in_shm_mem): @@ -212,7 +211,7 @@ def _fetch_dp_prompt_cache( def _transfer_dp_kv_cache(self, my_match: List[Tuple], other_match: List[Tuple]): other_shm_reqs = [] for match, index in other_match: - shm_req, kv_len, value_tensor = match + shm_req, kv_len, value_tensor, _ = match trans_len = kv_len - shm_req.dp_origin_kv_len if shm_req.dp_max_kv_rank == self.dp_rank_in_node: self.shared_kv_indexes.arr[index, 0:trans_len] = value_tensor[shm_req.dp_origin_kv_len : kv_len] @@ -229,7 +228,7 @@ def _transfer_dp_kv_cache(self, my_match: List[Tuple], other_match: List[Tuple]) trans_info = [] alloc_size = 0 for match, index in my_match: - shm_req, kv_len, value_tensor = match + shm_req, kv_len, value_tensor, _ = match trans_len = shm_req.dp_max_kv_len - kv_len if trans_len > 0 and shm_req.dp_max_kv_rank != self.dp_rank_in_node: move_token_indexes.extend(self.shared_kv_indexes.arr[index, 0:trans_len]) diff --git a/lightllm/utils/log_utils.py b/lightllm/utils/log_utils.py index f15309d5c..ce4588f52 100644 --- a/lightllm/utils/log_utils.py +++ b/lightllm/utils/log_utils.py @@ -7,7 +7,7 @@ import time from typing import Optional -_FORMAT = "%(levelname)s %(asctime)s [%(filename)s:%(lineno)d] %(message)s" +_FORMAT = "%(levelname)s %(asctime)s,%(msecs)03d [%(filename)s:%(lineno)d] %(message)s" _DATE_FORMAT = "%m-%d %H:%M:%S" _LOG_LEVEL = os.environ.get("LIGHTLLM_LOG_LEVEL", "debug") diff --git a/unit_tests/common/kv_trans_kernel/test_kv_trans_v2.py b/unit_tests/common/kv_trans_kernel/test_kv_trans_v2.py index 509415da0..dd34402dd 100644 --- a/unit_tests/common/kv_trans_kernel/test_kv_trans_v2.py +++ b/unit_tests/common/kv_trans_kernel/test_kv_trans_v2.py @@ -1,7 +1,7 @@ import pytest import torch import random -from lightllm.common.kv_trans_kernel.kv_trans_v2 import kv_trans_v2_for_p_node, kv_trans_v2_for_d_node +from lightllm.common.kv_trans_kernel.kv_trans_v2 import kv_trans_v2_for_p_node, kv_trans_v2_for_d_node, kv_trans_for_dp @pytest.mark.parametrize( @@ -73,5 +73,51 @@ def 
test_kv_trans_v2_for_d_node(token_num): return +@pytest.mark.parametrize( + "token_num", + [token_num for token_num in range(5, 10)], +) +def test_kv_trans_for_dp(token_num): + card_num = 8 + dp_size_in_node = 4 + layer_num = 3 + head_num = 2 + head_dim = 512 + kv_buffer_token_num = 512 + rank_in_dp = 1 + + card_num_per_d = card_num // dp_size_in_node + + # 创建多层的 mem,每个 mem 包含所有层的数据 + mems = [] + for _ in range(card_num): + mems.append( + torch.randn((layer_num, kv_buffer_token_num, head_num, head_dim), dtype=torch.float16, device="cuda") + ) + + input_mems = torch.tensor([e.data_ptr() for e in mems], dtype=torch.uint64, device="cuda") + input_idx = [random.randint(0, kv_buffer_token_num - 1) for _ in range(token_num)] + input_idx = torch.tensor(input_idx, dtype=torch.int32, device="cuda") + input_dp_idx = [random.randint(0, dp_size_in_node - 1) for _ in range(token_num)] + input_dp_idx = torch.tensor(input_dp_idx, dtype=torch.int32, device="cuda") + + true_output = torch.zeros((layer_num, kv_buffer_token_num, head_num, head_dim), dtype=torch.float16, device="cuda") + test_output = torch.zeros((layer_num, kv_buffer_token_num, head_num, head_dim), dtype=torch.float16, device="cuda") + output_idx = torch.arange(0, token_num, 1, dtype=torch.int32, device="cuda") + + kv_trans_for_dp(input_mems, input_idx, input_dp_idx, test_output, output_idx, dp_size_in_node, rank_in_dp) + + # 验证结果 + for dest_token_index, src_token_index, dp_index in zip( + list(range(token_num)), input_idx.cpu().numpy(), input_dp_idx.cpu().numpy() + ): + mem_index = rank_in_dp + dp_index * card_num_per_d + # 所有 layer 都从同一个 mem 的对应层读取 + true_output[:, dest_token_index, :, :] = mems[mem_index][:, src_token_index, :, :] + + assert torch.equal(true_output, test_output), "kv_trans_for_dp output mismatch" + return + + if __name__ == "__main__": pytest.main() From 26879eab36fa735ae65578db8edc08b1e33b742d Mon Sep 17 00:00:00 2001 From: wanzihao <1060304770@qq.com> Date: Mon, 10 Nov 2025 08:32:57 +0000 Subject: [PATCH 14/34] fix multiple visits to fd --- .../kv_cache_mem_manager/mem_manager.py | 23 +++++++++++-------- .../server/router/model_infer/infer_batch.py | 1 - .../decode_node_impl/decode_trans_obj.py | 3 ++- .../decode_node_impl/decode_trans_process.py | 2 +- .../prefill_node_impl/prefill_trans_obj.py | 3 ++- .../prefill_trans_process.py | 2 +- .../mode_backend/dp_backend/impl.py | 2 +- 7 files changed, 20 insertions(+), 16 deletions(-) diff --git a/lightllm/common/kv_cache_mem_manager/mem_manager.py b/lightllm/common/kv_cache_mem_manager/mem_manager.py index 6bb75cc4c..be14f2d87 100755 --- a/lightllm/common/kv_cache_mem_manager/mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/mem_manager.py @@ -9,7 +9,7 @@ from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt from lightllm.utils.profile_max_tokens import get_available_gpu_memory, get_total_gpu_memory from lightllm.common.kv_trans_kernel.kv_trans import kv_trans -from lightllm.utils.dist_utils import get_current_rank_in_node +from lightllm.utils.dist_utils import get_current_rank_in_node, get_node_world_size from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args from lightllm.distributed.pynccl import PyNcclCommunicator from lightllm.utils.dist_utils import get_current_device_id @@ -433,18 +433,21 @@ def copy_kv_from_other_dp_ranks( ) def create_shm(self): - obj_bytes = ForkingPickler.dumps(self) - shm = create_or_link_shm( - f"{get_unique_server_name()}_mem_manager_{get_current_rank_in_node()}", 
LIGHTLLM_MEM_MANAGER_SHM_SIZE - ) - logger.info(f"create shm {shm.name} size {shm.size} obj size {len(obj_bytes)}") - shm.buf[0:4] = len(obj_bytes).to_bytes(4, "little") - shm.buf[4 : 4 + len(obj_bytes)] = obj_bytes + for rank_in_node in range(0, get_node_world_size()): + obj_bytes = ForkingPickler.dumps(self) + shm = create_or_link_shm( + f"{get_unique_server_name()}_mem_manager_{get_current_rank_in_node()}_{rank_in_node}", + LIGHTLLM_MEM_MANAGER_SHM_SIZE, + ) + logger.info(f"create shm {shm.name} size {shm.size} obj size {len(obj_bytes)}") + shm.buf[0:4] = len(obj_bytes).to_bytes(4, "little") + shm.buf[4 : 4 + len(obj_bytes)] = obj_bytes @staticmethod - def from_shm(rank_in_node): + def from_shm(mem_manager_rank_in_node, rank_in_node): shm = create_or_link_shm( - f"{get_unique_server_name()}_mem_manager_{rank_in_node}", LIGHTLLM_MEM_MANAGER_SHM_SIZE + f"{get_unique_server_name()}_mem_manager_{mem_manager_rank_in_node}_{rank_in_node}", + LIGHTLLM_MEM_MANAGER_SHM_SIZE, ) bytes_len = int.from_bytes(shm.buf[0:4], "little") obj_bytes = shm.buf[4 : 4 + bytes_len].tobytes() diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index f774fa3d9..ab2965887 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -361,7 +361,6 @@ def _init_all_state(self): self.shm_req = g_infer_context.shm_req_manager.get_req_obj_by_index(self.shm_index) self.shm_req.link_prompt_ids_shm_array() self.shm_req.link_logprobs_shm_array() - self.shm_req.link_kv_indexes_shm_array() self.sampling_param: InferSamplingParams = InferSamplingParams(self.shm_req, self.vocab_size) # 更新 nixl pd 分离模式下, prefill 节点需要开始传输的起始位置 diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_obj.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_obj.py index 8281cbe27..f462fbb3d 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_obj.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_obj.py @@ -281,7 +281,8 @@ def init_all(self, device_id: int, manager: "DecodeKVMoveManager"): self.task_out_queue, ) assert self.task_out_queue.get(timeout=30) == "proc_start" - manager._put_mem_manager_to_shm() + if self.device_id == 0: + manager._put_mem_manager_to_shm() assert self.task_out_queue.get(timeout=60) == "get_mem_managers_ok" return True diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py index 7965a9449..1a03a1021 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py @@ -113,7 +113,7 @@ def _init_env(args, device_id: int, task_in_queue: mp.Queue, task_out_queue: mp. 
# 从共享内存读取所有rank的mem_manager node_world_size = args.tp // args.nnodes - mem_managers: List[MemoryManager] = [MemoryManager.from_shm(rank) for rank in range(node_world_size)] + mem_managers: List[MemoryManager] = [MemoryManager.from_shm(rank, device_id) for rank in range(node_world_size)] task_out_queue.put("get_mem_managers_ok") connect_id_to_comm: Dict[str, PyNcclCommunicator] = {} diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_obj.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_obj.py index 68b86895d..237483db4 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_obj.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_obj.py @@ -355,7 +355,8 @@ def init_all(self, device_id: int, manager: "PrefillKVMoveManager"): self.task_out_queue, ) assert self.task_out_queue.get(timeout=30) == "proc_start" - manager._put_mem_manager_to_shm() + if self.device_id == 0: + manager._put_mem_manager_to_shm() assert self.task_out_queue.get(timeout=60) == "get_mem_managers_ok" return True diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py index 84e3f5800..8e4f9b04e 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py @@ -118,7 +118,7 @@ def _init_env( # 从共享内存读取所有rank的mem_manager node_world_size = args.tp // args.nnodes - mem_managers: List[MemoryManager] = [MemoryManager.from_shm(rank) for rank in range(node_world_size)] + mem_managers: List[MemoryManager] = [MemoryManager.from_shm(rank, device_id) for rank in range(node_world_size)] task_out_queue.put("get_mem_managers_ok") connect_id_to_comm: Dict[str, PyNcclCommunicator] = {} diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py index 94e0616d7..bbf4914f3 100644 --- a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py @@ -103,7 +103,7 @@ def init_custom(self): self.mem_managers = [] for rank_idx in range(self.node_world_size): if rank_idx != self.rank_in_node: - self.mem_managers.append(MemoryManager.from_shm(rank_idx)) + self.mem_managers.append(MemoryManager.from_shm(rank_idx, self.rank_in_node)) else: self.mem_managers.append(self.model.mem_manager) From 725eec39c42c6465f4e00c427425b140038109d7 Mon Sep 17 00:00:00 2001 From: wanzihao <1060304770@qq.com> Date: Tue, 11 Nov 2025 04:44:59 +0000 Subject: [PATCH 15/34] fix pd mem_manager get failed --- .../pd_mode/decode_node_impl/decode_trans_obj.py | 3 +++ .../pd_mode/decode_node_impl/decode_trans_process.py | 5 ++++- .../pd_mode/prefill_node_impl/prefill_trans_obj.py | 1 + .../pd_mode/prefill_node_impl/prefill_trans_process.py | 3 +++ 4 files changed, 11 insertions(+), 1 deletion(-) diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_obj.py 
b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_obj.py index f462fbb3d..575032189 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_obj.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_obj.py @@ -281,8 +281,11 @@ def init_all(self, device_id: int, manager: "DecodeKVMoveManager"): self.task_out_queue, ) assert self.task_out_queue.get(timeout=30) == "proc_start" + # 确保在子进程读取共享内存之前,主进程已经将 mem_manager 写入共享内存 if self.device_id == 0: manager._put_mem_manager_to_shm() + # 通知子进程可以从共享内存读取 mem_manager + self.task_in_queue.put("mem_managers_ready") assert self.task_out_queue.get(timeout=60) == "get_mem_managers_ok" return True diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py index 1a03a1021..84c669730 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py @@ -111,6 +111,9 @@ def _init_env(args, device_id: int, task_in_queue: mp.Queue, task_out_queue: mp. graceful_registry(inspect.currentframe().f_code.co_name) task_out_queue.put("proc_start") + # 等待主进程将 mem_manager 写入共享内存后的信号 + assert task_in_queue.get(timeout=60) == "mem_managers_ready" + # 从共享内存读取所有rank的mem_manager node_world_size = args.tp // args.nnodes mem_managers: List[MemoryManager] = [MemoryManager.from_shm(rank, device_id) for rank in range(node_world_size)] @@ -136,7 +139,7 @@ def _init_env(args, device_id: int, task_in_queue: mp.Queue, task_out_queue: mp. 
logger.warning(f"unexpected task type: {task}") except Exception as e: - logger.error(f"Fatal error happened in kv trans process: {e}") + logger.error(f"Fatal error happened in kv trans process: {e} in device {device_id}") raise diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_obj.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_obj.py index 237483db4..cdfacaecc 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_obj.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_obj.py @@ -357,6 +357,7 @@ def init_all(self, device_id: int, manager: "PrefillKVMoveManager"): assert self.task_out_queue.get(timeout=30) == "proc_start" if self.device_id == 0: manager._put_mem_manager_to_shm() + self.task_in_queue.put("mem_managers_ready") assert self.task_out_queue.get(timeout=60) == "get_mem_managers_ok" return True diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py index 8e4f9b04e..deffcbaf3 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py @@ -116,6 +116,9 @@ def _init_env( ) task_out_queue.put("proc_start") + # 等待主进程将 mem_manager 写入共享内存后的信号 + assert task_in_queue.get(timeout=60) == "mem_managers_ready" + # 从共享内存读取所有rank的mem_manager node_world_size = args.tp // args.nnodes mem_managers: List[MemoryManager] = [MemoryManager.from_shm(rank, device_id) for rank in range(node_world_size)] From dad0b83e47642ded66985da61f0ae719ae9b1ec6 Mon Sep 17 00:00:00 2001 From: wanzihao <1060304770@qq.com> Date: Tue, 11 Nov 2025 05:28:05 +0000 Subject: [PATCH 16/34] fix release other shm_reqs --- .../router/model_infer/mode_backend/dp_backend/impl.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py index bbf4914f3..4fc15aa0f 100644 --- a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py @@ -161,10 +161,12 @@ def _fetch_dp_prompt_cache( my_match.append((shm_req, kv_len, value_tensor)) # match all the reqs in other dp ranks. 
+ other_shm_reqs = [] if self.rank_in_dp == 0: for r in other_reqs: _, shm_index, _, _ = r shm_req = g_infer_context.shm_req_manager.get_req_obj_by_index(shm_index) + other_shm_reqs.append(shm_req) sampling_param = InferSamplingParams(shm_req, g_infer_context.vocab_size) if sampling_param.disable_prompt_cache: continue @@ -192,9 +194,12 @@ def _fetch_dp_prompt_cache( shm_req, kv_len, value_tensor = shm_index_to_match[shm_index] match = (shm_req, kv_len, value_tensor, suggested_dp_index) + # 需要传输的 if suggested_dp_index != shm_req.dp_max_kv_rank: + # 需要获取的 if suggested_dp_index == self.dp_rank_in_node: my_trans_match.append((match, transfer_count)) + # 需要给其他dp的 else: other_trans_match.append((match, transfer_count)) transfer_count += 1 @@ -208,16 +213,15 @@ def _fetch_dp_prompt_cache( if transfer_count > 0: self._transfer_dp_kv_cache(my_trans_match, other_trans_match) + self.release_all_shm_reqs(other_shm_reqs) + def _transfer_dp_kv_cache(self, my_match: List[Tuple], other_match: List[Tuple]): - other_shm_reqs = [] for match, index in other_match: shm_req, kv_len, value_tensor, _ = match trans_len = kv_len - shm_req.dp_origin_kv_len if shm_req.dp_max_kv_rank == self.dp_rank_in_node: self.shared_kv_indexes.arr[index, 0:trans_len] = value_tensor[shm_req.dp_origin_kv_len : kv_len] - other_shm_reqs.append(shm_req) - self.release_all_shm_reqs(other_shm_reqs) dist.barrier(group=self.node_nccl_group) if not my_match: From 6cc9982f94eaa27bc02f6604030e0a503045b32e Mon Sep 17 00:00:00 2001 From: wanzihao <1060304770@qq.com> Date: Tue, 11 Nov 2025 06:06:19 +0000 Subject: [PATCH 17/34] add use_for_pd_trans to avoid duplicate name overwriting --- .../kv_cache_mem_manager/mem_manager.py | 21 ++++++++++++------- .../mode_backend/dp_backend/impl.py | 6 ++++-- .../pd_nixl/base_kv_move_manager.py | 2 -- .../pd_nixl/decode_node_impl/decode_impl.py | 5 ++--- .../decode_kv_move_manager.py | 9 +++----- .../decode_node_impl/decode_trans_process.py | 12 +++++------ .../pd_nixl/prefill_node_impl/prefill_impl.py | 5 ++--- .../prefill_kv_move_manager.py | 16 ++++++-------- .../prefill_trans_process.py | 10 +++++---- .../mode_backend/pd_nixl/trans_process_obj.py | 1 - 10 files changed, 43 insertions(+), 44 deletions(-) diff --git a/lightllm/common/kv_cache_mem_manager/mem_manager.py b/lightllm/common/kv_cache_mem_manager/mem_manager.py index be14f2d87..ab5cf2fed 100755 --- a/lightllm/common/kv_cache_mem_manager/mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/mem_manager.py @@ -432,11 +432,16 @@ def copy_kv_from_other_dp_ranks( rank_in_dp=rank_in_dp, ) - def create_shm(self): + def create_shm(self, use_for_pd_trans: bool = True): + if use_for_pd_trans: + shm_name = f"{get_unique_server_name()}_mem_manager_for_pd_{get_current_rank_in_node()}" + else: + shm_name = f"{get_unique_server_name()}_mem_manager_{get_current_rank_in_node()}" + for rank_in_node in range(0, get_node_world_size()): obj_bytes = ForkingPickler.dumps(self) shm = create_or_link_shm( - f"{get_unique_server_name()}_mem_manager_{get_current_rank_in_node()}_{rank_in_node}", + f"{shm_name}_{rank_in_node}", LIGHTLLM_MEM_MANAGER_SHM_SIZE, ) logger.info(f"create shm {shm.name} size {shm.size} obj size {len(obj_bytes)}") @@ -444,11 +449,13 @@ def create_shm(self): shm.buf[4 : 4 + len(obj_bytes)] = obj_bytes @staticmethod - def from_shm(mem_manager_rank_in_node, rank_in_node): - shm = create_or_link_shm( - f"{get_unique_server_name()}_mem_manager_{mem_manager_rank_in_node}_{rank_in_node}", - LIGHTLLM_MEM_MANAGER_SHM_SIZE, - ) + def 
from_shm(mem_manager_rank_in_node, rank_in_node, use_for_pd_trans: bool = True): + if use_for_pd_trans: + shm_name = f"{get_unique_server_name()}_mem_manager_for_pd_{mem_manager_rank_in_node}_{rank_in_node}" + else: + shm_name = f"{get_unique_server_name()}_mem_manager_{mem_manager_rank_in_node}_{rank_in_node}" + logger.info(f"from shm {shm_name}") + shm = create_or_link_shm(shm_name, LIGHTLLM_MEM_MANAGER_SHM_SIZE) bytes_len = int.from_bytes(shm.buf[0:4], "little") obj_bytes = shm.buf[4 : 4 + bytes_len].tobytes() shm.close() diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py index 4fc15aa0f..d4eec1d0c 100644 --- a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py @@ -83,7 +83,7 @@ def init_custom(self): mp.reductions.reduce_tensor.__code__ = reduce_tensor.__code__ # Create shared memory for mem_manager - self.model.mem_manager.create_shm() + self.model.mem_manager.create_shm(use_for_pd_trans=False) # Create shared ShmArray for kv_indexes transfer # Use a small buffer to save shared memory @@ -103,7 +103,9 @@ def init_custom(self): self.mem_managers = [] for rank_idx in range(self.node_world_size): if rank_idx != self.rank_in_node: - self.mem_managers.append(MemoryManager.from_shm(rank_idx, self.rank_in_node)) + self.mem_managers.append( + MemoryManager.from_shm(rank_idx, self.rank_in_node, use_for_pd_trans=False) + ) else: self.mem_managers.append(self.model.mem_manager) diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/base_kv_move_manager.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/base_kv_move_manager.py index e4460073a..eb1172802 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/base_kv_move_manager.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/base_kv_move_manager.py @@ -20,7 +20,6 @@ def __init__( self, args: StartArgs, info_queue: mp.Queue, - mem_queues: List[mp.Queue], start_trans_process_func: Callable, up_status_in_queue: Optional[mp.SimpleQueue] = None, ): @@ -33,7 +32,6 @@ def __init__( assert self.dp_world_size <= self.node_world_size self.info_queue = info_queue - self.mem_queues = mem_queues self.ret_obj_queue = queue.Queue() self.kv_trans_processes: List[KVTransProcess] = [None] * self.node_world_size diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_impl.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_impl.py index 8529bf81b..8ecabbcf2 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_impl.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_impl.py @@ -25,9 +25,8 @@ def init_custom(self): mp.reductions.reduce_tensor.__code__ = reduce_tensor.__code__ - # 将当前的内存管理器放入到队列中,供kv传输进程获取后使用 - for _ in range(self.node_world_size): - self.mem_queue.put(self.model.mem_manager) + # 将内存管理器写入共享内存,供kv传输进程获取后使用 + self.model.mem_manager.create_shm() return def _init_reqs(self, reqs: List[Tuple]): diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_kv_move_manager.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_kv_move_manager.py index 17d963e9a..877c5c12d 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_kv_move_manager.py +++ 
b/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_kv_move_manager.py @@ -14,9 +14,9 @@ logger = init_logger(__name__) -def start_decode_kv_move_manager_process(args, info_queue: mp.Queue, mem_queues: List[mp.Queue]): +def start_decode_kv_move_manager_process(args, info_queue: mp.Queue): event = mp.Event() - proc = mp.Process(target=_init_env, args=(args, info_queue, mem_queues, event)) + proc = mp.Process(target=_init_env, args=(args, info_queue, event)) proc.start() event.wait() assert proc.is_alive() @@ -24,7 +24,7 @@ def start_decode_kv_move_manager_process(args, info_queue: mp.Queue, mem_queues: return -def _init_env(args, info_queue: mp.Queue, mem_queues: List[mp.Queue], event: mp.Event): +def _init_env(args, info_queue: mp.Queue, event: mp.Event): import lightllm.utils.rpyc_fix_utils as _ # 注册graceful 退出的处理 @@ -40,7 +40,6 @@ def _init_env(args, info_queue: mp.Queue, mem_queues: List[mp.Queue], event: mp. manager = DecodeKVMoveManager( args=args, info_queue=info_queue, - mem_queues=mem_queues, start_trans_process_func=start_decode_trans_process, up_status_in_queue=up_status_in_queue, ) @@ -56,14 +55,12 @@ def __init__( self, args: StartArgs, info_queue: mp.Queue, - mem_queues: List[mp.Queue], start_trans_process_func: Callable, up_status_in_queue: mp.SimpleQueue, ): super().__init__( args=args, info_queue=info_queue, - mem_queues=mem_queues, start_trans_process_func=start_trans_process_func, up_status_in_queue=up_status_in_queue, ) diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_trans_process.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_trans_process.py index e7dc30ad8..9091f5ca9 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_trans_process.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_trans_process.py @@ -31,12 +31,9 @@ def start_decode_trans_process( device_id, task_in_queue: mp.Queue, task_out_queue: mp.Queue, - mem_queues: List[mp.Queue], up_status_in_queue: Optional[mp.SimpleQueue], ): - proc = mp.Process( - target=_init_env, args=(args, device_id, task_in_queue, task_out_queue, mem_queues, up_status_in_queue) - ) + proc = mp.Process(target=_init_env, args=(args, device_id, task_in_queue, task_out_queue, up_status_in_queue)) proc.start() assert proc.is_alive() logger.info(f"prefill trans kv process for device: {device_id} started!") @@ -48,7 +45,6 @@ def _init_env( device_id: int, task_in_queue: mp.Queue, task_out_queue: mp.Queue, - mem_queues: List[mp.Queue], up_status_in_queue: Optional[mp.SimpleQueue], ): torch.backends.cudnn.enabled = False @@ -58,7 +54,11 @@ def _init_env( graceful_registry(inspect.currentframe().f_code.co_name) task_out_queue.put("proc_start") - mem_managers: List[MemoryManager] = [mem_queue.get(timeout=60) for mem_queue in mem_queues] + + # 从共享内存读取所有rank的mem_manager + node_world_size = args.tp // args.nnodes + mem_managers: List[MemoryManager] = [MemoryManager.from_shm(rank, device_id) for rank in range(node_world_size)] + task_out_queue.put("get_mem_managers_ok") manager = _DecodeTransModule( diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_impl.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_impl.py index e9e61571a..16542bc37 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_impl.py +++ 
b/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_impl.py @@ -27,9 +27,8 @@ def init_custom(self): mp.reductions.reduce_tensor.__code__ = reduce_tensor.__code__ - # 将当前的内存管理器放入到队列中,供kv传输进程获取后使用 - for _ in range(self.node_world_size): - self.mem_queue.put(self.model.mem_manager) + # 将内存管理器写入共享内存,供kv传输进程获取后使用 + self.model.mem_manager.create_shm() return def _filter_not_ready_reqs(self, req_ids: List[int]) -> List[InferReq]: diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_kv_move_manager.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_kv_move_manager.py index fb6845979..ac8026e58 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_kv_move_manager.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_kv_move_manager.py @@ -13,9 +13,9 @@ logger = init_logger(__name__) -def start_prefill_kv_move_manager_process(args, info_queue: mp.Queue, mem_queues: List[mp.Queue]): +def start_prefill_kv_move_manager_process(args, info_queue: mp.Queue): event = mp.Event() - proc = mp.Process(target=_init_env, args=(args, info_queue, mem_queues, event)) + proc = mp.Process(target=_init_env, args=(args, info_queue, event)) proc.start() event.wait() assert proc.is_alive() @@ -23,7 +23,7 @@ def start_prefill_kv_move_manager_process(args, info_queue: mp.Queue, mem_queues return -def _init_env(args, info_queue: mp.Queue, mem_queues: List[mp.Queue], event: mp.Event): +def _init_env(args, info_queue: mp.Queue, event: mp.Event): import lightllm.utils.rpyc_fix_utils as _ # 注册graceful 退出的处理 @@ -32,7 +32,7 @@ def _init_env(args, info_queue: mp.Queue, mem_queues: List[mp.Queue], event: mp. from .prefill_trans_process import start_prefill_trans_process manager = PrefillKVMoveManager( - args=args, info_queue=info_queue, mem_queues=mem_queues, start_trans_process_func=start_prefill_trans_process + args=args, info_queue=info_queue, start_trans_process_func=start_prefill_trans_process ) assert manager is not None event.set() @@ -42,12 +42,8 @@ def _init_env(args, info_queue: mp.Queue, mem_queues: List[mp.Queue], event: mp. 
class PrefillKVMoveManager(BaseKVMoveManager): - def __init__( - self, args: StartArgs, info_queue: mp.Queue, mem_queues: List[mp.Queue], start_trans_process_func: Callable - ): - super().__init__( - args=args, info_queue=info_queue, mem_queues=mem_queues, start_trans_process_func=start_trans_process_func - ) + def __init__(self, args: StartArgs, info_queue: mp.Queue, start_trans_process_func: Callable): + super().__init__(args=args, info_queue=info_queue, start_trans_process_func=start_trans_process_func) return # ================================================================================== diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_trans_process.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_trans_process.py index 8265afc27..bf6d80f06 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_trans_process.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_trans_process.py @@ -25,10 +25,9 @@ def start_prefill_trans_process( device_id, task_in_queue: mp.Queue, task_out_queue: mp.Queue, - mem_queues: List[mp.Queue], up_status_in_queue: Optional[mp.SimpleQueue] = None, ): - proc = mp.Process(target=_init_env, args=(args, device_id, task_in_queue, task_out_queue, mem_queues)) + proc = mp.Process(target=_init_env, args=(args, device_id, task_in_queue, task_out_queue)) proc.start() assert proc.is_alive() logger.info(f"prefill trans kv process for device: {device_id} started!") @@ -40,7 +39,6 @@ def _init_env( device_id: int, task_in_queue: mp.Queue, task_out_queue: mp.Queue, - mem_queues: List[mp.Queue], ): torch.backends.cudnn.enabled = False @@ -48,7 +46,11 @@ def _init_env( torch.cuda.set_device(device_id) graceful_registry(inspect.currentframe().f_code.co_name) task_out_queue.put("proc_start") - mem_managers: List[MemoryManager] = [mem_queue.get(timeout=60) for mem_queue in mem_queues] + + # 从共享内存读取所有rank的mem_manager + node_world_size = args.tp // args.nnodes + mem_managers: List[MemoryManager] = [MemoryManager.from_shm(rank, device_id) for rank in range(node_world_size)] + task_out_queue.put("get_mem_managers_ok") manager = _PrefillTransModule( diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/trans_process_obj.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/trans_process_obj.py index 7afd2ca9a..073ecf23d 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/trans_process_obj.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/trans_process_obj.py @@ -35,7 +35,6 @@ def init_all( device_id, self.task_in_queue, self.task_out_queue, - manager.mem_queues, up_status_in_queue, ) assert self.task_out_queue.get(timeout=30) == "proc_start" From a97df661c6ebcd06204d6940a9688ceef156d53c Mon Sep 17 00:00:00 2001 From: wanzihao <1060304770@qq.com> Date: Fri, 14 Nov 2025 04:32:29 +0000 Subject: [PATCH 18/34] minor change --- lightllm/server/api_start.py | 2 +- lightllm/utils/log_utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 332a45283..9cc3d38c2 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -269,7 +269,7 @@ def normal_or_p_d_start(args): args.router_max_wait_tokens = 0 send_and_receive_node_ip(args) # 多机用于收发node ip - # PD 分离模式下必须禁用 DP prompt cache fetch,且 dp 必须 > 1 + # dp 必须 > 1 if args.enable_dp_prompt_cache_fetch and args.dp <= 1: 
args.enable_dp_prompt_cache_fetch = False logger.warning( diff --git a/lightllm/utils/log_utils.py b/lightllm/utils/log_utils.py index ce4588f52..f15309d5c 100644 --- a/lightllm/utils/log_utils.py +++ b/lightllm/utils/log_utils.py @@ -7,7 +7,7 @@ import time from typing import Optional -_FORMAT = "%(levelname)s %(asctime)s,%(msecs)03d [%(filename)s:%(lineno)d] %(message)s" +_FORMAT = "%(levelname)s %(asctime)s [%(filename)s:%(lineno)d] %(message)s" _DATE_FORMAT = "%m-%d %H:%M:%S" _LOG_LEVEL = os.environ.get("LIGHTLLM_LOG_LEVEL", "debug") From 47212f0da05bcb74bc58486d2882c381b79cf32d Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Sat, 15 Nov 2025 12:43:26 +0800 Subject: [PATCH 19/34] add test.py --- test.py | 51 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) create mode 100644 test.py diff --git a/test.py b/test.py new file mode 100644 index 000000000..689ae9b77 --- /dev/null +++ b/test.py @@ -0,0 +1,51 @@ +import torch +import multiprocessing +from multiprocessing.reduction import ForkingPickler + + +def worker(serialized): + torch.cuda.set_device(0) + tensor = ForkingPickler.loads(serialized) + print("In worker process:", tensor, type(tensor)) + # import time + # time.sleep(100) + + +def worker1(serialized): + torch.cuda.set_device(1) + tensor = ForkingPickler.loads(serialized) + print("In worker process:", tensor, type(tensor)) + # import time + # time.sleep(100) + + +if __name__ == "__main__": + torch.multiprocessing.set_start_method("spawn") + + torch.cuda.set_device(0) + # Create a tensor on the CUDA device + a = torch.zeros((100,), device="cuda").cuda() + + same = ForkingPickler.dumps(a) + serialized = same.tobytes() + + # Create a new process + process = multiprocessing.Process(target=worker, args=(serialized,)) + + # Start the process + process.start() + + process1 = multiprocessing.Process(target=worker1, args=(serialized,)) + + # Start the process + process1.start() + + # Wait for the process to finish + process.join() + process1.join() + print(a) + import time + + time.sleep(10) + + print("Main process finished.") From cbb1b842f801b9bd859f56a03bd7ba99a9b5e426 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Sat, 15 Nov 2025 12:58:07 +0800 Subject: [PATCH 20/34] improve mem_manager --- .../kv_cache_mem_manager/mem_manager.py | 38 ++++++++----------- lightllm/utils/shm_utils.py | 7 ++-- 2 files changed, 19 insertions(+), 26 deletions(-) diff --git a/lightllm/common/kv_cache_mem_manager/mem_manager.py b/lightllm/common/kv_cache_mem_manager/mem_manager.py index ab5cf2fed..08e7e95c7 100755 --- a/lightllm/common/kv_cache_mem_manager/mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/mem_manager.py @@ -19,7 +19,6 @@ from multiprocessing.reduction import ForkingPickler logger = init_logger(__name__) -LIGHTLLM_MEM_MANAGER_SHM_SIZE = int(os.getenv("LIGHTLLM_MEM_MANAGER_SHM_SIZE", 16 * 1024)) class MemoryManager: @@ -432,30 +431,23 @@ def copy_kv_from_other_dp_ranks( rank_in_dp=rank_in_dp, ) - def create_shm(self, use_for_pd_trans: bool = True): - if use_for_pd_trans: - shm_name = f"{get_unique_server_name()}_mem_manager_for_pd_{get_current_rank_in_node()}" - else: - shm_name = f"{get_unique_server_name()}_mem_manager_{get_current_rank_in_node()}" - - for rank_in_node in range(0, get_node_world_size()): - obj_bytes = ForkingPickler.dumps(self) - shm = create_or_link_shm( - f"{shm_name}_{rank_in_node}", - LIGHTLLM_MEM_MANAGER_SHM_SIZE, - ) - logger.info(f"create shm {shm.name} size {shm.size} obj size {len(obj_bytes)}") - shm.buf[0:4] = 
len(obj_bytes).to_bytes(4, "little") - shm.buf[4 : 4 + len(obj_bytes)] = obj_bytes + def write_to_shm(self): + """ + 将 mem manager 写入到 shm中,方便pd分离等特性直接从中读取,不依赖进程间队列。 + """ + shm_name = f"{get_unique_server_name()}_mem_manager_{get_current_rank_in_node()}" + obj_bytes = ForkingPickler.dumps(self).tobytes() + shm = create_or_link_shm(name=shm_name, expected_size=len(obj_bytes) + 4, force_mode="create") + logger.info(f"create shm {shm.name} size {shm.size} for mem manger shared buffer") + shm.buf[0:4] = len(obj_bytes).to_bytes(4, "little") + shm.buf[4 : 4 + len(obj_bytes)] = obj_bytes + self.__shm_io_buffer = shm @staticmethod - def from_shm(mem_manager_rank_in_node, rank_in_node, use_for_pd_trans: bool = True): - if use_for_pd_trans: - shm_name = f"{get_unique_server_name()}_mem_manager_for_pd_{mem_manager_rank_in_node}_{rank_in_node}" - else: - shm_name = f"{get_unique_server_name()}_mem_manager_{mem_manager_rank_in_node}_{rank_in_node}" - logger.info(f"from shm {shm_name}") - shm = create_or_link_shm(shm_name, LIGHTLLM_MEM_MANAGER_SHM_SIZE) + def loads_from_shm(rank_in_node: int) -> "MemoryManager": + shm_name = f"{get_unique_server_name()}_mem_manager_{rank_in_node}" + logger.info(f"get memmanager from shm {shm_name}") + shm = create_or_link_shm(name=shm_name, expected_size=-1, force_mode="link") bytes_len = int.from_bytes(shm.buf[0:4], "little") obj_bytes = shm.buf[4 : 4 + bytes_len].tobytes() shm.close() diff --git a/lightllm/utils/shm_utils.py b/lightllm/utils/shm_utils.py index 7ceea6f8b..0a25d8214 100644 --- a/lightllm/utils/shm_utils.py +++ b/lightllm/utils/shm_utils.py @@ -10,7 +10,7 @@ def create_or_link_shm(name, expected_size, force_mode=None, auto_cleanup=False) """ Args: name: name of the shared memory - expected_size: expected size of the shared memory + expected_size: expected size of the shared memory, if expected_size == -1, no check for size linked. 
force_mode: force mode - 'create': force create new shared memory, if exists, delete and create - 'link': force link to existing shared memory, if not exists, raise exception @@ -52,11 +52,12 @@ def _force_create_shm(name, expected_size, auto_cleanup): def _force_link_shm(name, expected_size): - """强制连接到已存在的共享内存""" + """强制连接到已存在的共享内存, + 如果 expected_size 为 -1, 则不进行link的size校验比对""" try: shm = shared_memory.SharedMemory(name=name) # 验证大小 - if shm.size != expected_size: + if expected_size != -1 and shm.size != expected_size: shm.close() raise ValueError(f"Shared memory {name} size mismatch: expected {expected_size}, got {shm.size}") # logger.info(f"Force linked to existing shared memory: {name} (size={expected_size})") From b548946e36092def61adceabd5e85fd40ed1d5b7 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Sat, 15 Nov 2025 13:09:24 +0800 Subject: [PATCH 21/34] write mem manager to shm --- .../router/model_infer/mode_backend/base_backend.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 95f0c9951..11b7e1743 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -218,6 +218,14 @@ def init_model(self, kvargs): if self.args.mtp_mode: self.init_mtp_draft_model(kvargs) + # 如果存在需要跨进程使用mem manger的特性,则将mem manager写入到 shm中,方便 + # 读取 + if ( + self.args.run_mode in ["nixl_prefill", "nixl_decode", "prefill", "decode"] + or self.args.enable_dp_prompt_cache_fetch + ): + self.model.mem_manager.write_to_shm() + # 启动infer_loop_thread, 启动两个线程进行推理,对于具备双batch推理折叠得场景 # 可以降低 cpu overhead,大幅提升gpu得使用率。 self.infer_loop_thread = threading.Thread(target=self.infer_loop, daemon=True) From 6270e0b24864a82931fa7f6ae182a8f013b4230f Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Sat, 15 Nov 2025 18:02:43 +0800 Subject: [PATCH 22/34] fix --- lightllm/common/kv_cache_mem_manager/mem_manager.py | 5 +++++ .../pd_mode/decode_node_impl/decode_infer_rpyc.py | 5 ----- .../decode_node_impl/decode_kv_move_manager.py | 7 ------- .../pd_mode/decode_node_impl/decode_trans_obj.py | 5 ----- .../pd_mode/decode_node_impl/decode_trans_process.py | 7 +++---- .../pd_mode/prefill_node_impl/prefill_infer_rpyc.py | 5 ----- .../prefill_node_impl/prefill_kv_move_manager.py | 9 +-------- .../pd_mode/prefill_node_impl/prefill_trans_obj.py | 3 --- .../prefill_node_impl/prefill_trans_process.py | 7 +++---- .../model_infer/mode_backend/dp_backend/impl.py | 11 +---------- .../pd_nixl/decode_node_impl/decode_impl.py | 3 +-- .../pd_nixl/decode_node_impl/decode_trans_process.py | 4 +++- .../pd_nixl/prefill_node_impl/prefill_impl.py | 3 --- .../prefill_node_impl/prefill_trans_process.py | 4 +++- 14 files changed, 20 insertions(+), 58 deletions(-) diff --git a/lightllm/common/kv_cache_mem_manager/mem_manager.py b/lightllm/common/kv_cache_mem_manager/mem_manager.py index 08e7e95c7..511687258 100755 --- a/lightllm/common/kv_cache_mem_manager/mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/mem_manager.py @@ -2,6 +2,7 @@ import os import torch import torch.distributed as dist +import torch.multiprocessing as mp from typing import List, Union from lightllm.common.kv_trans_kernel.kv_trans_v2 import kv_trans_for_dp from lightllm.server.pd_io_struct import KVMoveTask @@ -435,6 +436,10 @@ def write_to_shm(self): """ 将 mem manager 写入到 shm中,方便pd分离等特性直接从中读取,不依赖进程间队列。 """ + from 
lightllm.server.router.model_infer.mode_backend.continues_batch.pd_mode.p2p_fix import reduce_tensor + + mp.reductions.reduce_tensor.__code__ = reduce_tensor.__code__ + shm_name = f"{get_unique_server_name()}_mem_manager_{get_current_rank_in_node()}" obj_bytes = ForkingPickler.dumps(self).tobytes() shm = create_or_link_shm(name=shm_name, expected_size=len(obj_bytes) + 4, force_mode="create") diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_infer_rpyc.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_infer_rpyc.py index 202b624e5..696452b41 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_infer_rpyc.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_infer_rpyc.py @@ -166,11 +166,6 @@ def exposed_fail_to_realese_forzen_tokens(self, group_req_ids: List[int]): release_acquired_lock() return - def exposed_put_mem_manager_to_shm(self): - self.backend.model.mem_manager.create_shm() - logger.info("put mem manager to shm ok") - return - def exposed_unfrozen_time_out_reqs_tokens(self): acquire_lock_until_ready(self.backend.lock_nccl_group) if self.backend.dp_world_size == 1: diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py index 88904018b..4733a141b 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py @@ -86,7 +86,6 @@ def __init__(self, args, info_queue: mp.Queue): # _put_kv_received_to_radix_cache # _fail_to_realese_forzen_tokens # _unfrozen_time_out_reqs_tokens - # _put_mem_manager_to_shm # 上述接口都是 kv move manager 与推理进程进行交互的接口,主要用于申请锁定kv资源或者释放 # kv资源的接口 # ================================================================================== @@ -154,12 +153,6 @@ def _unfrozen_time_out_reqs_tokens(self) -> None: asyncio.run(self.wait_all_future_finish(futures)) return - def _put_mem_manager_to_shm(self) -> None: - with self.infer_rpyc_lock: - for obj in self.infer_rpyc_objs: - obj.put_mem_manager_to_shm() - return - # ================================================================================== # put_to_fail_release_task_queue 将因为一些原因失败,需要释放锁定的kv资源的请求放入到 # 对应的处理队列中,handle_fail_release_task_loop 是一个循环的线程,专门处理这些失败的请求 diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_obj.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_obj.py index 575032189..939f065fb 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_obj.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_obj.py @@ -281,11 +281,6 @@ def init_all(self, device_id: int, manager: "DecodeKVMoveManager"): self.task_out_queue, ) assert self.task_out_queue.get(timeout=30) == "proc_start" - # 确保在子进程读取共享内存之前,主进程已经将 mem_manager 写入共享内存 - if self.device_id == 0: - manager._put_mem_manager_to_shm() - # 通知子进程可以从共享内存读取 mem_manager - self.task_in_queue.put("mem_managers_ready") assert self.task_out_queue.get(timeout=60) 
== "get_mem_managers_ok" return True diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py index 84c669730..cdca63887 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py @@ -111,12 +111,11 @@ def _init_env(args, device_id: int, task_in_queue: mp.Queue, task_out_queue: mp. graceful_registry(inspect.currentframe().f_code.co_name) task_out_queue.put("proc_start") - # 等待主进程将 mem_manager 写入共享内存后的信号 - assert task_in_queue.get(timeout=60) == "mem_managers_ready" - # 从共享内存读取所有rank的mem_manager node_world_size = args.tp // args.nnodes - mem_managers: List[MemoryManager] = [MemoryManager.from_shm(rank, device_id) for rank in range(node_world_size)] + mem_managers: List[MemoryManager] = [ + MemoryManager.loads_from_shm(rank_in_node=rank) for rank in range(node_world_size) + ] task_out_queue.put("get_mem_managers_ok") connect_id_to_comm: Dict[str, PyNcclCommunicator] = {} diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_infer_rpyc.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_infer_rpyc.py index 12c6e9471..1f2dd52c5 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_infer_rpyc.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_infer_rpyc.py @@ -45,8 +45,3 @@ def exposed_remove_req_refs_from_prompt_cache(self, group_req_ids: List[int]): ) release_acquired_lock() return - - def exposed_put_mem_manager_to_shm(self): - self.backend.model.mem_manager.create_shm() - logger.info("put mem manager to shm ok") - return diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py index b0d1851c4..bd5af98ee 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py @@ -142,8 +142,7 @@ def check_trans_process_loop(self): raise e # ================================================================================== - # 与推理进程交互接口, _remove_req_refs_from_prompt_cache 和 - # _put_mem_manager_to_shm 都是通过 rpyc 与推理进程进行交互的接口 + # 与推理进程交互接口, _remove_req_refs_from_prompt_cache # ================================================================================== def _remove_req_refs_from_prompt_cache(self, tasks: List[KVMoveTask]): @@ -163,12 +162,6 @@ def _remove_req_refs_from_prompt_cache(self, tasks: List[KVMoveTask]): asyncio.run(self.wait_all_future_finish(futures)) return - def _put_mem_manager_to_shm(self): - with self.infer_rpyc_lock: - for obj in self.infer_rpyc_objs: - obj.put_mem_manager_to_shm() - return - async def wait_all_future_finish(self, futures: List[AsyncResult]): await asyncio.gather(*[asyncio.to_thread(future.wait) for future in futures]) return diff --git 
a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_obj.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_obj.py index cdfacaecc..022be4559 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_obj.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_obj.py @@ -355,9 +355,6 @@ def init_all(self, device_id: int, manager: "PrefillKVMoveManager"): self.task_out_queue, ) assert self.task_out_queue.get(timeout=30) == "proc_start" - if self.device_id == 0: - manager._put_mem_manager_to_shm() - self.task_in_queue.put("mem_managers_ready") assert self.task_out_queue.get(timeout=60) == "get_mem_managers_ok" return True diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py index deffcbaf3..a328e3e08 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py @@ -116,12 +116,11 @@ def _init_env( ) task_out_queue.put("proc_start") - # 等待主进程将 mem_manager 写入共享内存后的信号 - assert task_in_queue.get(timeout=60) == "mem_managers_ready" - # 从共享内存读取所有rank的mem_manager node_world_size = args.tp // args.nnodes - mem_managers: List[MemoryManager] = [MemoryManager.from_shm(rank, device_id) for rank in range(node_world_size)] + mem_managers: List[MemoryManager] = [ + MemoryManager.loads_from_shm(rank_in_node=rank) for rank in range(node_world_size) + ] task_out_queue.put("get_mem_managers_ok") connect_id_to_comm: Dict[str, PyNcclCommunicator] = {} diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py index d4eec1d0c..8ee8add4f 100644 --- a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py @@ -27,7 +27,6 @@ from lightllm.common.basemodel.triton_kernel.mtp_utils import mtp_scatter_next_token_ids from .control_state import DPControlState from lightllm.common.mem_manager import MemoryManager -import torch.multiprocessing as mp min_trans_token_num = int(os.getenv("LIGHTLLM_MIN_TRANS_TOKEN_NUM", "512")) dp_kv_transfer_req_num = int(os.getenv("LIGHTLLM_DP_KV_TRANSFER_REQ_NUM", "16")) @@ -76,15 +75,9 @@ def init_custom(self): if self.enable_dp_prompt_cache_fetch: torch.cuda.set_device(get_current_device_id()) - from lightllm.server.router.model_infer.mode_backend.continues_batch.pd_mode.p2p_fix import reduce_tensor from lightllm.server.core.objs.shm_array import ShmArray from lightllm.utils.envs_utils import get_unique_server_name - mp.reductions.reduce_tensor.__code__ = reduce_tensor.__code__ - - # Create shared memory for mem_manager - self.model.mem_manager.create_shm(use_for_pd_trans=False) - # Create shared ShmArray for kv_indexes transfer # Use a small buffer to save shared memory self.dp_kv_transfer_req_num = dp_kv_transfer_req_num @@ -103,9 +96,7 @@ def init_custom(self): self.mem_managers = [] for rank_idx in range(self.node_world_size): if rank_idx != self.rank_in_node: - self.mem_managers.append( - MemoryManager.from_shm(rank_idx, 
self.rank_in_node, use_for_pd_trans=False) - ) + self.mem_managers.append(MemoryManager.loads_from_shm(self.rank_in_node)) else: self.mem_managers.append(self.model.mem_manager) diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_impl.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_impl.py index 8ecabbcf2..07d99a3dc 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_impl.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_impl.py @@ -25,8 +25,7 @@ def init_custom(self): mp.reductions.reduce_tensor.__code__ = reduce_tensor.__code__ - # 将内存管理器写入共享内存,供kv传输进程获取后使用 - self.model.mem_manager.create_shm() + # TODO 如何支持不支持 P2P的场景 return def _init_reqs(self, reqs: List[Tuple]): diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_trans_process.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_trans_process.py index 9091f5ca9..b04cbb900 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_trans_process.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_trans_process.py @@ -57,7 +57,9 @@ def _init_env( # 从共享内存读取所有rank的mem_manager node_world_size = args.tp // args.nnodes - mem_managers: List[MemoryManager] = [MemoryManager.from_shm(rank, device_id) for rank in range(node_world_size)] + mem_managers: List[MemoryManager] = [ + MemoryManager.loads_from_shm(rank_in_node=rank) for rank in range(node_world_size) + ] task_out_queue.put("get_mem_managers_ok") diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_impl.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_impl.py index 16542bc37..55554cdf6 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_impl.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_impl.py @@ -26,9 +26,6 @@ def init_custom(self): from ..p2p_fix import reduce_tensor mp.reductions.reduce_tensor.__code__ = reduce_tensor.__code__ - - # 将内存管理器写入共享内存,供kv传输进程获取后使用 - self.model.mem_manager.create_shm() return def _filter_not_ready_reqs(self, req_ids: List[int]) -> List[InferReq]: diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_trans_process.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_trans_process.py index bf6d80f06..063ce5c6a 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_trans_process.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_trans_process.py @@ -49,7 +49,9 @@ def _init_env( # 从共享内存读取所有rank的mem_manager node_world_size = args.tp // args.nnodes - mem_managers: List[MemoryManager] = [MemoryManager.from_shm(rank, device_id) for rank in range(node_world_size)] + mem_managers: List[MemoryManager] = [ + MemoryManager.loads_from_shm(rank_in_node=rank) for rank in range(node_world_size) + ] task_out_queue.put("get_mem_managers_ok") From e1769a335e3868ce65753ed46370826bc5416fd4 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Mon, 17 Nov 2025 20:13:20 +0800 Subject: [PATCH 23/34] fix --- docs/CN/source/tutorial/deepseek_deployment.rst | 4 ++-- docs/EN/source/tutorial/deepseek_deployment.rst | 4 ++-- .../pd_mode/decode_node_impl/decode_impl.py | 7 ------- 
.../pd_mode/prefill_node_impl/prefill_impl.py | 7 ------- .../mode_backend/pd_nixl/decode_node_impl/decode_impl.py | 5 ----- .../mode_backend/pd_nixl/prefill_node_impl/prefill_impl.py | 5 ----- lightllm/utils/device_utils.py | 2 +- test/start_scripts/README.md | 2 +- test/start_scripts/multi_pd_master.sh | 4 ++-- test/start_scripts/single_pd_master/pd_decode.sh | 2 +- test/start_scripts/single_pd_master/pd_nixl_decode.sh | 2 +- test/start_scripts/single_pd_master/pd_nixl_prefill.sh | 2 +- test/start_scripts/single_pd_master/pd_prefill.sh | 2 +- 13 files changed, 12 insertions(+), 36 deletions(-) diff --git a/docs/CN/source/tutorial/deepseek_deployment.rst b/docs/CN/source/tutorial/deepseek_deployment.rst index 4c83fd76c..071d9405a 100644 --- a/docs/CN/source/tutorial/deepseek_deployment.rst +++ b/docs/CN/source/tutorial/deepseek_deployment.rst @@ -187,7 +187,7 @@ PD (Prefill-Decode) 分离模式将预填充和解码阶段分离部署,可以 export host=$1 export pd_master_ip=$2 nvidia-cuda-mps-control -d - MOE_MODE=EP KV_TRANS_USE_P2P=1 LOADWORKER=18 python -m lightllm.server.api_server \ + MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \ --model_dir /path/DeepSeek-R1 \ --run_mode "prefill" \ --tp 8 \ @@ -211,7 +211,7 @@ PD (Prefill-Decode) 分离模式将预填充和解码阶段分离部署,可以 export host=$1 export pd_master_ip=$2 nvidia-cuda-mps-control -d - MOE_MODE=EP KV_TRANS_USE_P2P=1 LOADWORKER=18 python -m lightllm.server.api_server \ + MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \ --model_dir /path/DeepSeek-R1 \ --run_mode "decode" \ --tp 8 \ diff --git a/docs/EN/source/tutorial/deepseek_deployment.rst b/docs/EN/source/tutorial/deepseek_deployment.rst index d924dfed3..6098411be 100755 --- a/docs/EN/source/tutorial/deepseek_deployment.rst +++ b/docs/EN/source/tutorial/deepseek_deployment.rst @@ -187,7 +187,7 @@ PD (Prefill-Decode) disaggregation mode separates prefill and decode stages for export host=$1 export pd_master_ip=$2 nvidia-cuda-mps-control -d - MOE_MODE=EP KV_TRANS_USE_P2P=1 LOADWORKER=18 python -m lightllm.server.api_server \ + MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \ --model_dir /path/DeepSeek-R1 \ --run_mode "prefill" \ --tp 8 \ @@ -208,7 +208,7 @@ PD (Prefill-Decode) disaggregation mode separates prefill and decode stages for export host=$1 export pd_master_ip=$2 nvidia-cuda-mps-control -d - MOE_MODE=EP KV_TRANS_USE_P2P=1 LOADWORKER=18 python -m lightllm.server.api_server \ + MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \ --model_dir /path/DeepSeek-R1 \ --run_mode "decode" \ --tp 8 \ diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl.py index f867d512d..b367a66a7 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl.py @@ -11,7 +11,6 @@ from rpyc.utils.server import ThreadedServer from lightllm.common.basemodel.infer_lock import g_router_lock from .decode_task_cache import g_success_kv_move_task_cache, KVMoveTask -from lightllm.utils.device_utils import kv_trans_use_p2p from lightllm.utils.envs_utils import get_unique_server_name from lightllm.utils.dist_utils import create_new_group_for_current_dp @@ -39,12 +38,6 @@ def init_custom(self): PDDecodeInferRpcServer(self), socket_path=socket_path, protocol_config={"allow_pickle": 
True} ) threading.Thread(target=lambda: t.start(), daemon=True).start() - - if kv_trans_use_p2p(): - from ..p2p_fix import reduce_tensor - - mp.reductions.reduce_tensor.__code__ = reduce_tensor.__code__ - return def _init_reqs(self, reqs: List[Tuple]): diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl.py index 441fc5cd8..8e7bddc64 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl.py @@ -11,7 +11,6 @@ from lightllm.common.basemodel.infer_lock import g_router_lock, g_infer_state_lock from rpyc.utils.server import ThreadedServer from .prefill_task_cache import g_kv_move_task_cache -from lightllm.utils.device_utils import kv_trans_use_p2p from lightllm.utils.envs_utils import get_unique_server_name from lightllm.utils.dist_utils import create_new_group_for_current_dp from lightllm.server.router.model_infer.mode_backend.chunked_prefill.impl import ChunkedPrefillBackend @@ -41,12 +40,6 @@ def init_custom(self): PDPrefillInferRpcServer(self), socket_path=socket_path, protocol_config={"allow_pickle": True} ) threading.Thread(target=lambda: t.start(), daemon=True).start() - - if kv_trans_use_p2p(): - from ..p2p_fix import reduce_tensor - - mp.reductions.reduce_tensor.__code__ = reduce_tensor.__code__ - return def _pre_handle_finished_reqs(self, finished_reqs): diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_impl.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_impl.py index 07d99a3dc..f1309ca9c 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_impl.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_impl.py @@ -18,12 +18,7 @@ def __init__(self, info_queue: mp.Queue) -> None: self.classed_req_strict_prefill = False def init_custom(self): - assert kv_trans_use_p2p() - if kv_trans_use_p2p(): - from ..p2p_fix import reduce_tensor - - mp.reductions.reduce_tensor.__code__ = reduce_tensor.__code__ # TODO 如何支持不支持 P2P的场景 return diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_impl.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_impl.py index 55554cdf6..6f5a6e17d 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_impl.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_impl.py @@ -21,11 +21,6 @@ def __init__(self, info_queue: mp.Queue) -> None: def init_custom(self): assert kv_trans_use_p2p() - - if kv_trans_use_p2p(): - from ..p2p_fix import reduce_tensor - - mp.reductions.reduce_tensor.__code__ = reduce_tensor.__code__ return def _filter_not_ready_reqs(self, req_ids: List[int]) -> List[InferReq]: diff --git a/lightllm/utils/device_utils.py b/lightllm/utils/device_utils.py index 4c32f2ab1..022c5ab40 100644 --- a/lightllm/utils/device_utils.py +++ b/lightllm/utils/device_utils.py @@ -112,7 +112,7 @@ def init_p2p(device_index): @lru_cache(maxsize=None) def kv_trans_use_p2p(): - return os.getenv("KV_TRANS_USE_P2P", "False").upper() in ["1", "TRUE", "ON"] + return not (os.getenv("DISABLE_KV_TRANS_USE_P2P", "False").upper() in ["1", "TRUE", 
"ON"]) def has_nvlink(): diff --git a/test/start_scripts/README.md b/test/start_scripts/README.md index aff1d973f..f5dae19b9 100644 --- a/test/start_scripts/README.md +++ b/test/start_scripts/README.md @@ -100,7 +100,7 @@ sh multi_pd_master/pd_decode.sh - `LOADWORKER`: Model loading thread count, recommended 8-18 - `MOE_MODE`: Expert parallelism mode, set to EP to enable expert parallelism -- `KV_TRANS_USE_P2P`: Enable P2P communication optimization +- `DISABLE_KV_TRANS_USE_P2P`: Disable P2P communication optimization to transfer kv data - `CUDA_VISIBLE_DEVICES`: Specify GPU devices to use ### Important Parameters diff --git a/test/start_scripts/multi_pd_master.sh b/test/start_scripts/multi_pd_master.sh index c4e8c21fb..7b8392392 100644 --- a/test/start_scripts/multi_pd_master.sh +++ b/test/start_scripts/multi_pd_master.sh @@ -6,7 +6,7 @@ python -m lightllm.server.api_server --model_dir /mtc/models/DeepSeek-V2-Lite-Ch python -m lightllm.server.api_server --model_dir /mtc/models/DeepSeek-V2-Lite-Chat --run_mode "pd_master" --host 10.120.114.74 --port 60012 --config_server_host 10.120.114.74 --config_server_port 60088 nvidia-cuda-mps-control -d -CUDA_VISIBLE_DEVICES=0 KV_TRANS_USE_P2P=1 LOADWORKER=1 python -m lightllm.server.api_server --model_dir /mtc/models/DeepSeek-V2-Lite-Chat \ +CUDA_VISIBLE_DEVICES=0 LOADWORKER=1 python -m lightllm.server.api_server --model_dir /mtc/models/DeepSeek-V2-Lite-Chat \ --run_mode "prefill" \ --host 10.120.178.74 \ --port 8019 \ @@ -20,7 +20,7 @@ CUDA_VISIBLE_DEVICES=0 KV_TRANS_USE_P2P=1 LOADWORKER=1 python -m lightllm.server --config_server_host 10.120.114.74 \ --config_server_port 60088 -CUDA_VISIBLE_DEVICES=1 KV_TRANS_USE_P2P=1 LOADWORKER=10 python -m lightllm.server.api_server --model_dir /mtc/models/DeepSeek-V2-Lite-Chat \ +CUDA_VISIBLE_DEVICES=1 LOADWORKER=10 python -m lightllm.server.api_server --model_dir /mtc/models/DeepSeek-V2-Lite-Chat \ --run_mode "decode" \ --host 10.120.178.74 \ --port 8121 \ diff --git a/test/start_scripts/single_pd_master/pd_decode.sh b/test/start_scripts/single_pd_master/pd_decode.sh index 1bf465746..ae16b96ad 100644 --- a/test/start_scripts/single_pd_master/pd_decode.sh +++ b/test/start_scripts/single_pd_master/pd_decode.sh @@ -5,7 +5,7 @@ export host=$1 export pd_master_ip=$2 nvidia-cuda-mps-control -d -MOE_MODE=EP KV_TRANS_USE_P2P=1 LOADWORKER=18 python -m lightllm.server.api_server \ +MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \ --model_dir /path/DeepSeek-R1 \ --run_mode "decode" \ --tp 8 \ diff --git a/test/start_scripts/single_pd_master/pd_nixl_decode.sh b/test/start_scripts/single_pd_master/pd_nixl_decode.sh index 1fe677195..1b43c11cc 100644 --- a/test/start_scripts/single_pd_master/pd_nixl_decode.sh +++ b/test/start_scripts/single_pd_master/pd_nixl_decode.sh @@ -10,7 +10,7 @@ export UCX_LOG_LEVEL=info export UCX_TLS=rc,cuda,gdr_copy nvidia-cuda-mps-control -d -MOE_MODE=EP KV_TRANS_USE_P2P=1 LOADWORKER=18 python -m lightllm.server.api_server \ +MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \ --model_dir /path/DeepSeek-R1 \ --run_mode "nixl_decode" \ --tp 8 \ diff --git a/test/start_scripts/single_pd_master/pd_nixl_prefill.sh b/test/start_scripts/single_pd_master/pd_nixl_prefill.sh index c583b4d7c..303de2975 100644 --- a/test/start_scripts/single_pd_master/pd_nixl_prefill.sh +++ b/test/start_scripts/single_pd_master/pd_nixl_prefill.sh @@ -11,7 +11,7 @@ export UCX_TLS=rc,cuda,gdr_copy export host=$1 export pd_master_ip=$2 nvidia-cuda-mps-control -d -MOE_MODE=EP KV_TRANS_USE_P2P=1 
LOADWORKER=18 python -m lightllm.server.api_server \ +MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \ --model_dir /path/DeepSeek-R1 \ --run_mode "nixl_prefill" \ --tp 8 \ diff --git a/test/start_scripts/single_pd_master/pd_prefill.sh b/test/start_scripts/single_pd_master/pd_prefill.sh index b15e4ef70..f6e2e4b68 100644 --- a/test/start_scripts/single_pd_master/pd_prefill.sh +++ b/test/start_scripts/single_pd_master/pd_prefill.sh @@ -5,7 +5,7 @@ export host=$1 export pd_master_ip=$2 nvidia-cuda-mps-control -d -MOE_MODE=EP KV_TRANS_USE_P2P=1 LOADWORKER=18 python -m lightllm.server.api_server \ +MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \ --model_dir /path/DeepSeek-R1 \ --run_mode "prefill" \ --tp 8 \ From c4f780f64110f43300e2f6d49d9fffc3df55dd6d Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 18 Nov 2025 08:18:00 +0000 Subject: [PATCH 24/34] fix --- .../kv_cache_mem_manager/mem_manager.py | 16 +- .../model_infer/mode_backend/base_backend.py | 2 +- .../dp_backend/dp_shared_kv_trans.py | 136 ++++++++++++ .../mode_backend/dp_backend/impl.py | 197 ++---------------- 4 files changed, 171 insertions(+), 180 deletions(-) create mode 100644 lightllm/server/router/model_infer/mode_backend/dp_backend/dp_shared_kv_trans.py diff --git a/lightllm/common/kv_cache_mem_manager/mem_manager.py b/lightllm/common/kv_cache_mem_manager/mem_manager.py index 511687258..a26a07109 100755 --- a/lightllm/common/kv_cache_mem_manager/mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/mem_manager.py @@ -16,6 +16,7 @@ from lightllm.utils.dist_utils import get_current_device_id from lightllm.utils.config_utils import get_num_key_value_heads from lightllm.common.kv_trans_kernel.nixl_kv_trans import page_io +from lightllm.utils.device_utils import kv_trans_use_p2p from lightllm.utils.shm_utils import create_or_link_shm from multiprocessing.reduction import ForkingPickler @@ -432,13 +433,22 @@ def copy_kv_from_other_dp_ranks( rank_in_dp=rank_in_dp, ) - def write_to_shm(self): + def write_to_shm(self, req_manager): """ 将 mem manager 写入到 shm中,方便pd分离等特性直接从中读取,不依赖进程间队列。 """ - from lightllm.server.router.model_infer.mode_backend.continues_batch.pd_mode.p2p_fix import reduce_tensor + if kv_trans_use_p2p(): + from lightllm.server.router.model_infer.mode_backend.continues_batch.pd_mode.p2p_fix import reduce_tensor - mp.reductions.reduce_tensor.__code__ = reduce_tensor.__code__ + mp.reductions.reduce_tensor.__code__ = reduce_tensor.__code__ + + from lightllm.common.req_manager import ReqManager + + req_manager: ReqManager = req_manager + + # 这个地方是一个不太优雅的设计,但是暂时这么做,可以让dp shared kv swap模块直接访问 req_manager 中的 req_to_token_indexs + # 避免过多无用的数据复制和传输开销。 + self.req_to_token_indexs: torch.Tensor = req_manager.req_to_token_indexs shm_name = f"{get_unique_server_name()}_mem_manager_{get_current_rank_in_node()}" obj_bytes = ForkingPickler.dumps(self).tobytes() diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 11b7e1743..e36c17de8 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -224,7 +224,7 @@ def init_model(self, kvargs): self.args.run_mode in ["nixl_prefill", "nixl_decode", "prefill", "decode"] or self.args.enable_dp_prompt_cache_fetch ): - self.model.mem_manager.write_to_shm() + self.model.mem_manager.write_to_shm(req_manager=self.model.req_manager) # 启动infer_loop_thread, 
启动两个线程进行推理,对于具备双batch推理折叠得场景 # 可以降低 cpu overhead,大幅提升gpu得使用率。 diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/dp_shared_kv_trans.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/dp_shared_kv_trans.py new file mode 100644 index 000000000..e11af4f0b --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/dp_shared_kv_trans.py @@ -0,0 +1,136 @@ +# 该文件用于提供在数据dp并行的推理模式下,共享kv cache trans相关的功能函数模块 +import numpy as np +import dataclasses +import torch +from typing import List +from lightllm.common.mem_manager import MemoryManager +from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args +from lightllm.utils.dist_utils import get_dp_rank_in_node +from lightllm.server.core.objs.shm_array import ShmArray +from ...infer_batch import InferReq + + +class DPKVSharedMoudle: + _KV_LEN_INDEX = 0 + _REQ_IDX_INDEX = 1 + + def __init__(self, max_req_num: int, max_req_seq_len: int, dp_size_in_node: int, backend): + from .impl import DPChunkedPrefillBackend + + self.backend: DPChunkedPrefillBackend = backend + self.max_req_num = max_req_num + self.max_req_seq_len = max_req_seq_len + + # 0 代表 kv_len, 1 代表 radix_cache_len + self.shared_req_infos = ShmArray( + name=f"{get_unique_server_name()}_dp_shared_req_infos", + shape=(self.max_req_num, dp_size_in_node, 2), + dtype=np.int64, + ) + self.shared_req_infos.create_shm() + self.dp_rank_in_node = get_dp_rank_in_node() + assert get_env_start_args().diverse_mode is False + + def fill_reqs_info( + self, + reqs: List[InferReq], + req_dp_ranks: List[int], + ): + """ + 填充请求的 kv 信息到共享内存中 + """ + self.backend.node_nccl_group.barrier() + self.shared_req_infos.arr[0 : len(reqs), self.dp_rank_in_node, self._KV_LEN_INDEX] = [ + req.cur_kv_len for req in reqs + ] + self.shared_req_infos.arr[0 : len(reqs), self.dp_rank_in_node, self._REQ_IDX_INDEX] = [ + req.req_idx for req in reqs + ] + return + + def build_shared_kv_trans_tasks( + self, + reqs: List[InferReq], + req_dp_ranks: List[int], + ) -> List["TransTask"]: + """ + 构建共享kv交换信息 + """ + from lightllm.server.router.model_infer.infer_batch import g_infer_context + + self.backend.node_nccl_group.barrier() + + trans_tasks: List[TransTask] = [] + rank_max_radix_cache_lens = np.max( + self.shared_req_infos.arr[0 : len(reqs), :, self._KV_LEN_INDEX], axis=1, keepdims=False + ) + # 如果发现自己是dp_rank 最小, radix_cache_len 最长的请求,则将数据写入到共享内存中。 + for req_index, req, max_req_radix_cache_len, req_dp_rank in zip( + list(range(len(reqs))), reqs, rank_max_radix_cache_lens, req_dp_ranks + ): + # 当前请求是本 dp_rank 负责的 + is_current_dp_handle = req_dp_rank == self.dp_rank_in_node + trans_size = max_req_radix_cache_len - req.cur_kv_len + + if is_current_dp_handle and trans_size > 0 and g_infer_context.get_can_alloc_token_num() > trans_size: + g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(trans_size) + mem_indexes = self.backend.model.mem_manager.alloc(trans_size) + max_kv_len_dp_rank = self.shared_req_infos.arr[req_index, :, self._KV_LEN_INDEX].argmax() + max_kv_len_req_idx = int(self.shared_req_infos.arr[req_index, max_kv_len_dp_rank, self._REQ_IDX_INDEX]) + max_kv_len_mem_manager_index = ( + max_kv_len_dp_rank * self.backend.dp_world_size + self.backend.dp_rank_in_node + ) + max_kv_len_mem_manager: MemoryManager = self.backend.mem_managers[max_kv_len_mem_manager_index] + max_kv_len_mem_indexes = max_kv_len_mem_manager.req_to_token_indexs[ + max_kv_len_req_idx, req.cur_kv_len : max_req_radix_cache_len + ] + trans_tasks.append( + TransTask( + req=req, + 
mem_indexes=mem_indexes, + max_kv_len_dp_rank=int(max_kv_len_dp_rank), + max_kv_len_mem_manager_index=int(max_kv_len_mem_manager_index), + max_kv_len_mem_indexes=max_kv_len_mem_indexes, + ) + ) + + return trans_tasks + + def kv_trans(self, trans_tasks: List["TransTask"]): + from lightllm.server.router.model_infer.infer_batch import g_infer_context + + self.backend.node_nccl_group.barrier() + # kv 传输 + + # move_token_indexes = torch.tensor(move_token_indexes, dtype=torch.int64, device="cuda") + # token_dp_indexes = torch.tensor(token_dp_indexes, dtype=torch.int32, device="cuda") + + # self.model.mem_manager.copy_kv_from_other_dp_ranks( + # mem_managers=self.mem_managers, + # move_token_indexes=move_token_indexes, + # token_dp_indexes=token_dp_indexes, + # mem_indexes=mem_indexes, + # dp_size_in_node=self.dp_size_in_node, + # rank_in_dp=self.rank_in_dp, + # ) + # self.logger.info(f"dp_i {self.dp_rank_in_node} transfer kv tokens num: {alloc_size}") + + self.backend.node_nccl_group.barrier() + for trans_task in trans_tasks: + g_infer_context.req_manager.req_to_token_indexs[ + trans_task.req.req_idx, + trans_task.req.cur_kv_len : (trans_task.req.cur_kv_len + len(trans_task.mem_indexes)), + ] = trans_task.mem_indexes + trans_task.req.cur_kv_len += len(trans_task.mem_indexes) + if self.backend.is_master_in_dp: + trans_task.req.shm_req.shm_cur_kv_len = trans_task.req.cur_kv_len + self.backend.node_nccl_group.barrier() + + +@dataclasses +class TransTask: + req: InferReq + mem_indexes: torch.Tensor + max_kv_len_dp_rank: int + max_kv_len_mem_manager_index: int + max_kv_len_mem_indexes: torch.Tensor diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py index 8ee8add4f..3ba1dfe63 100644 --- a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py @@ -27,9 +27,7 @@ from lightllm.common.basemodel.triton_kernel.mtp_utils import mtp_scatter_next_token_ids from .control_state import DPControlState from lightllm.common.mem_manager import MemoryManager - -min_trans_token_num = int(os.getenv("LIGHTLLM_MIN_TRANS_TOKEN_NUM", "512")) -dp_kv_transfer_req_num = int(os.getenv("LIGHTLLM_DP_KV_TRANSFER_REQ_NUM", "16")) +from .dp_shared_kv_trans import DPKVSharedMoudle class DPChunkedPrefillBackend(ModeBackend): @@ -39,7 +37,6 @@ def __init__(self) -> None: # 用于控制每一步是执行prefill 和 decode 还是跳过 self.control_state_machine = DPControlState(backend=self) self.enable_dp_prompt_cache_fetch = get_env_start_args().enable_dp_prompt_cache_fetch - self.min_trans_token_num = min_trans_token_num # 在 mtp 模式下切换绑定的prefill 和 decode 函数 if get_env_start_args().mtp_mode: @@ -75,21 +72,12 @@ def init_custom(self): if self.enable_dp_prompt_cache_fetch: torch.cuda.set_device(get_current_device_id()) - from lightllm.server.core.objs.shm_array import ShmArray - from lightllm.utils.envs_utils import get_unique_server_name - - # Create shared ShmArray for kv_indexes transfer - # Use a small buffer to save shared memory - self.dp_kv_transfer_req_num = dp_kv_transfer_req_num - max_len = get_env_start_args().max_req_total_len + 8 - self.shared_kv_indexes_name = f"{get_unique_server_name()}_shared_kv_indexes_global" - self.shared_kv_indexes = ShmArray( - self.shared_kv_indexes_name, (self.dp_kv_transfer_req_num, max_len), dtype=np.int64 + self.dp_kv_shared_moudle = DPKVSharedMoudle( + max_req_num=self.args.running_max_req_size, + max_req_seq_len=self.args.max_req_total_len + 
8, + dp_size_in_node=self.dp_size_in_node, + backend=self, ) - # Only rank_in_node == 0 creates the shared memory - if self.rank_in_node == 0: - self.shared_kv_indexes.create_shm() - dist.barrier(group=self.node_nccl_group) # Collect mem_managers from all ranks @@ -99,172 +87,29 @@ def init_custom(self): self.mem_managers.append(MemoryManager.loads_from_shm(self.rank_in_node)) else: self.mem_managers.append(self.model.mem_manager) - - # Other ranks link to the shared memory - if self.rank_in_node != 0: - self.shared_kv_indexes.link_shm() return def _init_reqs(self, reqs: List[Tuple]): - my_reqs = reqs - other_reqs = [] - if self.dp_size_in_node != 1: - dp_rank_in_node = self.dp_rank_in_node - my_reqs = [req for req in reqs if req[3] == dp_rank_in_node] - other_reqs = [req for req in reqs if req[3] != dp_rank_in_node] + if not self.args.enable_dp_prompt_cache_fetch: + return super()._init_reqs(reqs) - g_infer_state_lock.acquire() - infer_reqs = g_infer_context.add_reqs(my_reqs, init_prefix_cache=not self.enable_dp_prompt_cache_fetch) - if self.enable_dp_prompt_cache_fetch: - self._fetch_dp_prompt_cache(infer_reqs, other_reqs=other_reqs, origin_reqs=reqs) - for r in infer_reqs: - r._match_radix_cache() + dp_rank_in_node = self.dp_rank_in_node + current_dp_reqs = [req for req in reqs if req[3] == dp_rank_in_node] + other_dp_reqs = [req for req in reqs if req[3] != dp_rank_in_node] - g_infer_state_lock.release() - req_ids = [e[0] for e in my_reqs] - return req_ids - - def _match_radix_cache(self, shm_req): - input_token_ids = shm_req.shm_prompt_ids.arr[0 : shm_req.input_len] - key = torch.tensor(input_token_ids, dtype=torch.int64, device="cpu") - key = key[0 : len(key) - 1] # 最后一个不需要,因为需要一个额外的token,让其在prefill的时候输出下一个token的值 - _, kv_len, value_tensor = g_infer_context.radix_cache.match_prefix(key, update_refs=False) - return kv_len, value_tensor - - def _fetch_dp_prompt_cache( - self, infer_reqs: List[InferReq], other_reqs: List[Tuple] = [], origin_reqs: List[Tuple] = [] - ): - # shm_index_2_index = {r[1]: i for i, r in enumerate(origin_reqs)} - my_match = [] - other_match = [] - # match all the reqs in this dp rank. - for r in infer_reqs: - if r.sampling_param.disable_prompt_cache: - continue - shm_req = r.shm_req - - kv_len, value_tensor = self._match_radix_cache(shm_req) - # only the first rank is ok - if self.rank_in_dp == 0: - with g_infer_context.shm_req_manager.get_req_lock_by_index(shm_req.index_in_shm_mem): - shm_req.dp_origin_kv_len = kv_len - if kv_len > shm_req.dp_max_kv_len: - shm_req.dp_max_kv_len = kv_len - shm_req.dp_max_kv_rank = self.dp_rank_in_node # 单机 - my_match.append((shm_req, kv_len, value_tensor)) - - # match all the reqs in other dp ranks. 
- other_shm_reqs = [] - if self.rank_in_dp == 0: - for r in other_reqs: - _, shm_index, _, _ = r - shm_req = g_infer_context.shm_req_manager.get_req_obj_by_index(shm_index) - other_shm_reqs.append(shm_req) - sampling_param = InferSamplingParams(shm_req, g_infer_context.vocab_size) - if sampling_param.disable_prompt_cache: - continue - shm_req.link_prompt_ids_shm_array() - - kv_len, value_tensor = self._match_radix_cache(shm_req) - with g_infer_context.shm_req_manager.get_req_lock_by_index(shm_req.index_in_shm_mem): - if kv_len > shm_req.dp_max_kv_len: - shm_req.dp_max_kv_len = kv_len - shm_req.dp_max_kv_rank = self.dp_rank_in_node # 单机 - other_match.append((shm_req, kv_len, value_tensor)) - - # wait all the ranks to finish the match - dist.barrier(group=self.node_nccl_group) - - # 创建 shm_index 到匹配结果的映射 - shm_index_to_match = {match[0].index_in_shm_mem: match for match in my_match} - shm_index_to_match.update({match[0].index_in_shm_mem: match for match in other_match}) - - my_trans_match = [] - other_trans_match = [] - transfer_count = 0 - for r in origin_reqs: - _, shm_index, _, suggested_dp_index = r - shm_req, kv_len, value_tensor = shm_index_to_match[shm_index] - match = (shm_req, kv_len, value_tensor, suggested_dp_index) - - # 需要传输的 - if suggested_dp_index != shm_req.dp_max_kv_rank: - # 需要获取的 - if suggested_dp_index == self.dp_rank_in_node: - my_trans_match.append((match, transfer_count)) - # 需要给其他dp的 - else: - other_trans_match.append((match, transfer_count)) - transfer_count += 1 - - if transfer_count == self.dp_kv_transfer_req_num: - self._transfer_dp_kv_cache(my_trans_match, other_trans_match) - my_trans_match = [] - other_trans_match = [] - transfer_count = 0 - - if transfer_count > 0: - self._transfer_dp_kv_cache(my_trans_match, other_trans_match) - - self.release_all_shm_reqs(other_shm_reqs) - - def _transfer_dp_kv_cache(self, my_match: List[Tuple], other_match: List[Tuple]): - for match, index in other_match: - shm_req, kv_len, value_tensor, _ = match - trans_len = kv_len - shm_req.dp_origin_kv_len - if shm_req.dp_max_kv_rank == self.dp_rank_in_node: - self.shared_kv_indexes.arr[index, 0:trans_len] = value_tensor[shm_req.dp_origin_kv_len : kv_len] - - dist.barrier(group=self.node_nccl_group) - - if not my_match: - return + g_infer_state_lock.acquire() - move_token_indexes = [] - token_dp_indexes = [] - trans_info = [] - alloc_size = 0 - for match, index in my_match: - shm_req, kv_len, value_tensor, _ = match - trans_len = shm_req.dp_max_kv_len - kv_len - if trans_len > 0 and shm_req.dp_max_kv_rank != self.dp_rank_in_node: - move_token_indexes.extend(self.shared_kv_indexes.arr[index, 0:trans_len]) - token_dp_indexes.extend([shm_req.dp_max_kv_rank] * trans_len) - trans_info.append((shm_req, kv_len, value_tensor, trans_len)) - alloc_size += trans_len - - if alloc_size < self.min_trans_token_num: - return + infer_reqs = g_infer_context.add_reqs(reqs, init_prefix_cache=True) + req_dp_ranks = [req[3] for req in reqs] + self.dp_kv_shared_moudle.fill_reqs_info(reqs=infer_reqs, req_dp_ranks=req_dp_ranks) + trans_taskes = self.dp_kv_shared_moudle.build_shared_kv_trans_tasks(reqs=infer_reqs, req_dp_ranks=req_dp_ranks) + self.dp_kv_shared_moudle.kv_trans(trans_tasks=trans_taskes) - g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(alloc_size) - mem_indexes = self.model.mem_manager.alloc(alloc_size).cuda() - - move_token_indexes = torch.tensor(move_token_indexes, dtype=torch.int64, device="cuda") - token_dp_indexes = torch.tensor(token_dp_indexes, dtype=torch.int32, 
device="cuda") - - self.model.mem_manager.copy_kv_from_other_dp_ranks( - mem_managers=self.mem_managers, - move_token_indexes=move_token_indexes, - token_dp_indexes=token_dp_indexes, - mem_indexes=mem_indexes, - dp_size_in_node=self.dp_size_in_node, - rank_in_dp=self.rank_in_dp, - ) - self.logger.info(f"dp_i {self.dp_rank_in_node} transfer kv tokens num: {alloc_size}") - - start_index = 0 - for shm_req, kv_len, value_tensor, trans_len in trans_info: - new_value_tensor = mem_indexes[start_index : start_index + trans_len].cpu() - start_index += trans_len - key = torch.tensor(shm_req.shm_prompt_ids.arr[0 : shm_req.dp_max_kv_len], dtype=torch.int64, device="cpu") - value_tensor = ( - torch.cat((value_tensor, new_value_tensor), dim=0) if value_tensor is not None else new_value_tensor - ) - g_infer_context.radix_cache.insert(key, value_tensor) + g_infer_context._filter(finished_request_ids=[req[0] for req in other_dp_reqs]) + g_infer_state_lock.release() - def release_all_shm_reqs(self, shm_reqs): - for shm_req in shm_reqs: - g_infer_context.shm_req_manager.put_back_req_obj(shm_req) + req_ids = [e[0] for e in current_dp_reqs] + return req_ids def infer_loop(self): torch.cuda.set_device(get_current_device_id()) From 1e1e18a80d9ba732c8a8f548ec19abeaf21d3057 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 18 Nov 2025 08:30:43 +0000 Subject: [PATCH 25/34] fix --- lightllm/common/kv_cache_mem_manager/mem_manager.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lightllm/common/kv_cache_mem_manager/mem_manager.py b/lightllm/common/kv_cache_mem_manager/mem_manager.py index a26a07109..22cb333e9 100755 --- a/lightllm/common/kv_cache_mem_manager/mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/mem_manager.py @@ -420,11 +420,11 @@ def copy_kv_from_other_dp_ranks( mems_ptr_list = [] for i in range(0, len(mem_managers)): mems_ptr_list.append(mem_managers[i].kv_buffer.data_ptr()) - self.mem_ptrs_tensor = torch.tensor(mems_ptr_list, dtype=torch.uint64, device="cuda") + self.mem_ptrs_tensor = torch.tensor(mems_ptr_list, dtype=torch.uint64, device="cpu", pin_memory=True) # 一次性传输所有层 kv_trans_for_dp( - input_mems=self.mem_ptrs_tensor, + input_mems=self.mem_ptrs_tensor.cuda(non_blocking=True), input_idx=move_token_indexes, input_dp_idx=token_dp_indexes, output=self.kv_buffer, From c0f567e0198b162dbdb29bc0a22b3c812183649e Mon Sep 17 00:00:00 2001 From: wanzihao <1060304770@qq.com> Date: Fri, 21 Nov 2025 02:23:12 +0000 Subject: [PATCH 26/34] fix --- .../kv_cache_mem_manager/mem_manager.py | 18 +- .../model_infer/mode_backend/base_backend.py | 16 +- .../decode_node_impl/decode_trans_process.py | 3 +- .../prefill_node_impl/prefill_impl_for_dp.py | 1 - .../prefill_trans_process.py | 3 +- .../dp_backend/dp_shared_kv_trans.py | 84 ++++++---- .../mode_backend/dp_backend/impl.py | 23 +-- .../decode_node_impl/decode_trans_process.py | 3 +- .../mode_backend/pd_nixl/p2p_fix.py | 155 ------------------ .../prefill_node_impl/prefill_impl_for_dp.py | 1 - .../prefill_trans_process.py | 3 +- 11 files changed, 86 insertions(+), 224 deletions(-) delete mode 100644 lightllm/server/router/model_infer/mode_backend/pd_nixl/p2p_fix.py diff --git a/lightllm/common/kv_cache_mem_manager/mem_manager.py b/lightllm/common/kv_cache_mem_manager/mem_manager.py index 22cb333e9..332fa6dcb 100755 --- a/lightllm/common/kv_cache_mem_manager/mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/mem_manager.py @@ -451,16 +451,18 @@ def write_to_shm(self, req_manager): self.req_to_token_indexs: torch.Tensor = 
req_manager.req_to_token_indexs shm_name = f"{get_unique_server_name()}_mem_manager_{get_current_rank_in_node()}" - obj_bytes = ForkingPickler.dumps(self).tobytes() - shm = create_or_link_shm(name=shm_name, expected_size=len(obj_bytes) + 4, force_mode="create") - logger.info(f"create shm {shm.name} size {shm.size} for mem manger shared buffer") - shm.buf[0:4] = len(obj_bytes).to_bytes(4, "little") - shm.buf[4 : 4 + len(obj_bytes)] = obj_bytes - self.__shm_io_buffer = shm + for rank_in_node in range(0, get_node_world_size() * 2): + obj_bytes = ForkingPickler.dumps(self).tobytes() + shm = create_or_link_shm( + name=f"{shm_name}_{rank_in_node}", expected_size=len(obj_bytes) + 4, force_mode="create" + ) + logger.info(f"create shm {shm.name} size {shm.size} for mem manger shared buffer") + shm.buf[0:4] = len(obj_bytes).to_bytes(4, "little") + shm.buf[4 : 4 + len(obj_bytes)] = obj_bytes @staticmethod - def loads_from_shm(rank_in_node: int) -> "MemoryManager": - shm_name = f"{get_unique_server_name()}_mem_manager_{rank_in_node}" + def loads_from_shm(rank_in_node: int, current_rank_in_node: int) -> "MemoryManager": + shm_name = f"{get_unique_server_name()}_mem_manager_{rank_in_node}_{current_rank_in_node}" logger.info(f"get memmanager from shm {shm_name}") shm = create_or_link_shm(name=shm_name, expected_size=-1, force_mode="link") bytes_len = int.from_bytes(shm.buf[0:4], "little") diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index e36c17de8..706adbb47 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -39,6 +39,7 @@ from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token from lightllm.server.pd_io_struct import NIXLChunckedTransTaskRet from .multi_level_kv_cache import MultiLevelKvCacheModule +from .dp_backend.dp_shared_kv_trans import init_dp_kv_shared class ModeBackend: @@ -209,7 +210,17 @@ def init_model(self, kvargs): [rank for rank in range(self.global_world_size)], backend="nccl" ) + if ( + self.args.run_mode in ["nixl_prefill", "nixl_decode", "prefill", "decode"] + or self.args.enable_dp_prompt_cache_fetch + ): + self.model.mem_manager.write_to_shm(req_manager=self.model.req_manager) + self.init_custom() + + if self.args.enable_dp_prompt_cache_fetch: + init_dp_kv_shared(self) + self.shm_reqs_io_buffer = ShmObjsIOBuffer() # 只会在 nixl pd 模式下才会使用,用于上传分块传输任务是否成功。 self.shm_nixl_trans_io_buffer = ShmObjsIOBuffer(tail_str="nixl") @@ -220,11 +231,6 @@ def init_model(self, kvargs): # 如果存在需要跨进程使用mem manger的特性,则将mem manager写入到 shm中,方便 # 读取 - if ( - self.args.run_mode in ["nixl_prefill", "nixl_decode", "prefill", "decode"] - or self.args.enable_dp_prompt_cache_fetch - ): - self.model.mem_manager.write_to_shm(req_manager=self.model.req_manager) # 启动infer_loop_thread, 启动两个线程进行推理,对于具备双batch推理折叠得场景 # 可以降低 cpu overhead,大幅提升gpu得使用率。 diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py index cdca63887..7023176e8 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py @@ -114,7 +114,8 @@ def _init_env(args, device_id: 
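The write_to_shm / loads_from_shm pair above serializes the MemoryManager with ForkingPickler into a named shared-memory segment, prefixed by a 4-byte little-endian length header, so sibling processes can rebuild the object without an inter-process queue. A minimal standalone sketch of that length-prefixed round-trip, using only the standard library and generic names (not the lightllm helpers), might look like the following; note that ForkingPickler additionally turns CUDA tensors into IPC handles, which plain pickle does not do.

import pickle
from multiprocessing import shared_memory

def dump_to_shm(obj, name: str) -> shared_memory.SharedMemory:
    # 4-byte little-endian length header followed by the pickled payload
    payload = pickle.dumps(obj)
    shm = shared_memory.SharedMemory(name=name, create=True, size=len(payload) + 4)
    shm.buf[0:4] = len(payload).to_bytes(4, "little")
    shm.buf[4:4 + len(payload)] = payload
    return shm  # keep a reference so the segment stays alive on the writer side

def load_from_shm(name: str):
    # link to an existing segment, read the header, then unpickle the payload
    shm = shared_memory.SharedMemory(name=name, create=False)
    size = int.from_bytes(shm.buf[0:4], "little")
    obj = pickle.loads(bytes(shm.buf[4:4 + size]))
    shm.close()
    return obj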
int, task_in_queue: mp.Queue, task_out_queue: mp. # 从共享内存读取所有rank的mem_manager node_world_size = args.tp // args.nnodes mem_managers: List[MemoryManager] = [ - MemoryManager.loads_from_shm(rank_in_node=rank) for rank in range(node_world_size) + MemoryManager.loads_from_shm(rank_in_node=rank, current_rank_in_node=device_id + node_world_size) + for rank in range(node_world_size) ] task_out_queue.put("get_mem_managers_ok") diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl_for_dp.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl_for_dp.py index 4e2c35153..2897f7141 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl_for_dp.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl_for_dp.py @@ -17,7 +17,6 @@ def __init__(self, info_queue: mp.Queue) -> None: def init_custom(self): ChunckedPrefillForPrefillNode.init_custom(self) - super().init_custom() return def _pre_handle_finished_reqs(self, finished_reqs): diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py index a328e3e08..f411e96a6 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py @@ -119,7 +119,8 @@ def _init_env( # 从共享内存读取所有rank的mem_manager node_world_size = args.tp // args.nnodes mem_managers: List[MemoryManager] = [ - MemoryManager.loads_from_shm(rank_in_node=rank) for rank in range(node_world_size) + MemoryManager.loads_from_shm(rank_in_node=rank, current_rank_in_node=device_id + node_world_size) + for rank in range(node_world_size) ] task_out_queue.put("get_mem_managers_ok") connect_id_to_comm: Dict[str, PyNcclCommunicator] = {} diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/dp_shared_kv_trans.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/dp_shared_kv_trans.py index e11af4f0b..ccab6934d 100644 --- a/lightllm/server/router/model_infer/mode_backend/dp_backend/dp_shared_kv_trans.py +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/dp_shared_kv_trans.py @@ -8,6 +8,7 @@ from lightllm.utils.dist_utils import get_dp_rank_in_node from lightllm.server.core.objs.shm_array import ShmArray from ...infer_batch import InferReq +from lightllm.utils.dist_utils import get_current_device_id class DPKVSharedMoudle: @@ -31,21 +32,18 @@ def __init__(self, max_req_num: int, max_req_seq_len: int, dp_size_in_node: int, self.dp_rank_in_node = get_dp_rank_in_node() assert get_env_start_args().diverse_mode is False - def fill_reqs_info( - self, - reqs: List[InferReq], - req_dp_ranks: List[int], - ): + def fill_reqs_info(self, reqs: List[InferReq]): """ 填充请求的 kv 信息到共享内存中 """ self.backend.node_nccl_group.barrier() - self.shared_req_infos.arr[0 : len(reqs), self.dp_rank_in_node, self._KV_LEN_INDEX] = [ - req.cur_kv_len for req in reqs - ] - self.shared_req_infos.arr[0 : len(reqs), self.dp_rank_in_node, self._REQ_IDX_INDEX] = [ - req.req_idx for req in reqs - ] + if self.backend.rank_in_dp == 0: + self.shared_req_infos.arr[0 : len(reqs), self.dp_rank_in_node, self._KV_LEN_INDEX] = [ + 
req.cur_kv_len for req in reqs + ] + self.shared_req_infos.arr[0 : len(reqs), self.dp_rank_in_node, self._REQ_IDX_INDEX] = [ + req.req_idx for req in reqs + ] return def build_shared_kv_trans_tasks( @@ -77,9 +75,7 @@ def build_shared_kv_trans_tasks( mem_indexes = self.backend.model.mem_manager.alloc(trans_size) max_kv_len_dp_rank = self.shared_req_infos.arr[req_index, :, self._KV_LEN_INDEX].argmax() max_kv_len_req_idx = int(self.shared_req_infos.arr[req_index, max_kv_len_dp_rank, self._REQ_IDX_INDEX]) - max_kv_len_mem_manager_index = ( - max_kv_len_dp_rank * self.backend.dp_world_size + self.backend.dp_rank_in_node - ) + max_kv_len_mem_manager_index = max_kv_len_dp_rank * self.backend.dp_world_size + self.backend.rank_in_dp max_kv_len_mem_manager: MemoryManager = self.backend.mem_managers[max_kv_len_mem_manager_index] max_kv_len_mem_indexes = max_kv_len_mem_manager.req_to_token_indexs[ max_kv_len_req_idx, req.cur_kv_len : max_req_radix_cache_len @@ -101,19 +97,29 @@ def kv_trans(self, trans_tasks: List["TransTask"]): self.backend.node_nccl_group.barrier() # kv 传输 - - # move_token_indexes = torch.tensor(move_token_indexes, dtype=torch.int64, device="cuda") - # token_dp_indexes = torch.tensor(token_dp_indexes, dtype=torch.int32, device="cuda") - - # self.model.mem_manager.copy_kv_from_other_dp_ranks( - # mem_managers=self.mem_managers, - # move_token_indexes=move_token_indexes, - # token_dp_indexes=token_dp_indexes, - # mem_indexes=mem_indexes, - # dp_size_in_node=self.dp_size_in_node, - # rank_in_dp=self.rank_in_dp, - # ) - # self.logger.info(f"dp_i {self.dp_rank_in_node} transfer kv tokens num: {alloc_size}") + if len(trans_tasks) > 0: + max_kv_len_mem_indexes = [] + max_kv_len_dp_ranks = [] + mem_indexes = [] + + for i, trans_task in enumerate(trans_tasks): + max_kv_len_mem_indexes.extend(trans_task.max_kv_len_mem_indexes) + max_kv_len_dp_ranks.extend([trans_task.max_kv_len_dp_rank] * len(trans_task.max_kv_len_mem_indexes)) + mem_indexes.extend(trans_task.mem_indexes) + + max_kv_len_mem_indexes_tensor = torch.tensor(max_kv_len_mem_indexes, dtype=torch.int64, device="cuda") + max_kv_len_dp_ranks_tensor = torch.tensor(max_kv_len_dp_ranks, dtype=torch.int32, device="cuda") + mem_indexes_tensor = torch.tensor(mem_indexes, dtype=torch.int64, device="cuda") + + self.backend.model.mem_manager.copy_kv_from_other_dp_ranks( + mem_managers=self.backend.mem_managers, + move_token_indexes=max_kv_len_mem_indexes_tensor, + token_dp_indexes=max_kv_len_dp_ranks_tensor, + mem_indexes=mem_indexes_tensor, + dp_size_in_node=self.backend.dp_size_in_node, + rank_in_dp=self.backend.rank_in_dp, + ) + self.backend.logger.info(f"dp_i {self.dp_rank_in_node} transfer kv tokens num: {len(mem_indexes_tensor)}") self.backend.node_nccl_group.barrier() for trans_task in trans_tasks: @@ -127,10 +133,32 @@ def kv_trans(self, trans_tasks: List["TransTask"]): self.backend.node_nccl_group.barrier() -@dataclasses +@dataclasses.dataclass class TransTask: req: InferReq mem_indexes: torch.Tensor max_kv_len_dp_rank: int max_kv_len_mem_manager_index: int max_kv_len_mem_indexes: torch.Tensor + + +def init_dp_kv_shared(backend): + if backend.enable_dp_prompt_cache_fetch: + torch.cuda.set_device(get_current_device_id()) + + backend.dp_kv_shared_moudle = DPKVSharedMoudle( + max_req_num=backend.args.running_max_req_size, + max_req_seq_len=backend.args.max_req_total_len + 8, + dp_size_in_node=backend.dp_size_in_node, + backend=backend, + ) + backend.node_nccl_group.barrier() + + # Collect mem_managers from all ranks + backend.mem_managers 
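The shared_req_infos table gives every DP rank the same view of who holds the longest cached prefix for each incoming request: each rank publishes its (cur_kv_len, req_idx) pair per request, and every rank then takes an argmax over the kv-len column to pick the donor rank and the donor's row in req_to_token_indexs. A small self-contained numpy illustration of that lookup, with toy shapes and values rather than the real ShmArray:

import numpy as np

KV_LEN, REQ_IDX = 0, 1            # column layout of the shared table
max_req_num, dp_size_in_node = 8, 4
shared = np.zeros((max_req_num, dp_size_in_node, 2), dtype=np.int64)

# dp rank 2 announces 96 cached tokens for request slot 0, stored under its local req_idx 17
shared[0, 2, KV_LEN] = 96
shared[0, 2, REQ_IDX] = 17

donor_rank = int(shared[0, :, KV_LEN].argmax())      # rank holding the longest prefix
donor_len = int(shared[0, donor_rank, KV_LEN])       # how many tokens it can donate
donor_req_idx = int(shared[0, donor_rank, REQ_IDX])  # row in the donor's req_to_token_indexs
print(donor_rank, donor_len, donor_req_idx)          # -> 2 96 17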
= [] + for rank_idx in range(backend.node_world_size): + if rank_idx != backend.rank_in_node: + backend.mem_managers.append(MemoryManager.loads_from_shm(rank_idx, backend.rank_in_node)) + else: + backend.mem_managers.append(backend.model.mem_manager) + return diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py index 3ba1dfe63..d43c910e3 100644 --- a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py @@ -68,27 +68,6 @@ def __init__(self) -> None: self.classed_req_strict_prefill = False return - def init_custom(self): - if self.enable_dp_prompt_cache_fetch: - torch.cuda.set_device(get_current_device_id()) - - self.dp_kv_shared_moudle = DPKVSharedMoudle( - max_req_num=self.args.running_max_req_size, - max_req_seq_len=self.args.max_req_total_len + 8, - dp_size_in_node=self.dp_size_in_node, - backend=self, - ) - dist.barrier(group=self.node_nccl_group) - - # Collect mem_managers from all ranks - self.mem_managers = [] - for rank_idx in range(self.node_world_size): - if rank_idx != self.rank_in_node: - self.mem_managers.append(MemoryManager.loads_from_shm(self.rank_in_node)) - else: - self.mem_managers.append(self.model.mem_manager) - return - def _init_reqs(self, reqs: List[Tuple]): if not self.args.enable_dp_prompt_cache_fetch: return super()._init_reqs(reqs) @@ -101,7 +80,7 @@ def _init_reqs(self, reqs: List[Tuple]): infer_reqs = g_infer_context.add_reqs(reqs, init_prefix_cache=True) req_dp_ranks = [req[3] for req in reqs] - self.dp_kv_shared_moudle.fill_reqs_info(reqs=infer_reqs, req_dp_ranks=req_dp_ranks) + self.dp_kv_shared_moudle.fill_reqs_info(reqs=infer_reqs) trans_taskes = self.dp_kv_shared_moudle.build_shared_kv_trans_tasks(reqs=infer_reqs, req_dp_ranks=req_dp_ranks) self.dp_kv_shared_moudle.kv_trans(trans_tasks=trans_taskes) diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_trans_process.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_trans_process.py index b04cbb900..203d65c65 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_trans_process.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_trans_process.py @@ -58,7 +58,8 @@ def _init_env( # 从共享内存读取所有rank的mem_manager node_world_size = args.tp // args.nnodes mem_managers: List[MemoryManager] = [ - MemoryManager.loads_from_shm(rank_in_node=rank) for rank in range(node_world_size) + MemoryManager.loads_from_shm(rank_in_node=rank, current_rank_in_node=device_id + node_world_size) + for rank in range(node_world_size) ] task_out_queue.put("get_mem_managers_ok") diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/p2p_fix.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/p2p_fix.py deleted file mode 100644 index 62609c4c9..000000000 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/p2p_fix.py +++ /dev/null @@ -1,155 +0,0 @@ -# mypy: allow-untyped-defs -import multiprocessing -import os -import threading -from multiprocessing.reduction import ForkingPickler -from multiprocessing.util import register_after_fork -from typing import Union - -import torch -import torch.utils.hooks -from torch._namedtensor_internals import check_serializing_named_tensor -from torch.multiprocessing.reductions import storage_from_cache, shared_cache, StorageWeakRef -from 
torch.multiprocessing.reductions import reduce_nested_tensor, reduce_sparse_tensor, rebuild_tensor - - -def p2p_fix_rebuild_cuda_tensor( - tensor_cls, - tensor_size, - tensor_stride, - tensor_offset, - storage_cls, - dtype, - storage_device, - storage_handle, - storage_size_bytes, - storage_offset_bytes, - requires_grad, - ref_counter_handle, - ref_counter_offset, - event_handle, - event_sync_required, -): - # 因为接收进程在将 tensor 对应的 handle重新转化为指针的时候 - # 在其c++源码中会将当前显卡切换到storage_device再做操作,这样 - # 得到的指针可能不是接收进程当前上下文设备可以访问的,所以在这里 - # hack 修改了使用的 storage_device,这样后续tritonkernel同时 - # 访问几张显卡上的数据,进行p2p操作就不会出问题了。 - storage_device = torch.cuda.current_device() - # If storage_handle is None, storage points to nullptr. - if storage_handle is None or storage_size_bytes == 0: - storage = storage_cls(0, dtype=dtype, device=storage_device, _internal=True) - else: - storage = storage_from_cache(storage_cls, (storage_handle, storage_offset_bytes)) - if storage is None: - torch.cuda._lazy_init() - storage = storage_cls._new_shared_cuda( - storage_device, - storage_handle, - storage_size_bytes, - storage_offset_bytes, - ref_counter_handle, - ref_counter_offset, - event_handle, - event_sync_required, - ) - shared_cache[(storage_handle, storage_offset_bytes)] = StorageWeakRef(storage) - else: - # We already ref counting this Storage, but producer needs new ref-counters to be released. - storage_cls._release_ipc_counter(ref_counter_handle, ref_counter_offset, device=storage_device) - - _storage = storage if isinstance(storage, torch.UntypedStorage) else storage._untyped_storage - - t = torch._utils._rebuild_tensor( - torch.storage.TypedStorage(wrap_storage=_storage, dtype=dtype, _internal=True), - tensor_offset, - tensor_size, - tensor_stride, - ) - - if tensor_cls == torch.nn.parameter.Parameter: - # It is crucial for integer tensors to receive - # the requires_grad=False as an argument in the constructor - t = torch.nn.parameter.Parameter(t, requires_grad=requires_grad) - else: - t.requires_grad = requires_grad - - return t - - -def reduce_tensor(tensor): - if tensor.requires_grad and not tensor.is_leaf: - raise RuntimeError( - "Cowardly refusing to serialize non-leaf tensor which requires_grad, " - "since autograd does not support crossing process boundaries. " - "If you just want to transfer the data, call detach() on the tensor " - "before serializing (e.g., putting it on the queue)." 
- ) - - check_serializing_named_tensor(tensor) - torch.utils.hooks.warn_if_has_hooks(tensor) - - from torch.nested._internal.nested_tensor import NestedTensor - - if tensor.is_nested and not isinstance(tensor, NestedTensor): - return reduce_nested_tensor(tensor) - - if tensor.layout in { - torch.sparse_coo, - torch.sparse_csr, - torch.sparse_bsr, - torch.sparse_csc, - torch.sparse_bsc, - }: - return reduce_sparse_tensor(tensor) - - storage = tensor._typed_storage() - - if storage._untyped_storage.device.type == "cuda": - ( - device, - handle, - storage_size_bytes, - storage_offset_bytes, - ref_counter_handle, - ref_counter_offset, - event_handle, - event_sync_required, - ) = storage._share_cuda_() - tensor_offset = tensor.storage_offset() - shared_cache[handle] = StorageWeakRef(storage) - # _backward_hooks purposely omitted here, see - # Note [Don't serialize hooks] - from lightllm.server.router.model_infer.mode_backend.continues_batch.pd_mode.p2p_fix import ( - p2p_fix_rebuild_cuda_tensor, - ) - - return ( - p2p_fix_rebuild_cuda_tensor, - ( - type(tensor), - tensor.size(), - tensor.stride(), - tensor_offset, # tensor offset in its storage - type(storage), - tensor.dtype, - device, - handle, # identifier which CUDA allocation is the storage in. - storage_size_bytes, # size(in bytes) of the storage - storage_offset_bytes, # offset(in bytes) of the storage in the CUDA allocation - tensor.requires_grad, - ref_counter_handle, - ref_counter_offset, - event_handle, - event_sync_required, - ), - ) - - # _backward_hooks purposely omitted here, see Note [Don't serialize hooks] - metadata = ( - tensor.storage_offset(), - tensor.size(), - tensor.stride(), - tensor.requires_grad, - ) - return (rebuild_tensor, (type(tensor), storage, metadata)) diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_impl_for_dp.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_impl_for_dp.py index 2c6c295bc..eed98399e 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_impl_for_dp.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_impl_for_dp.py @@ -18,7 +18,6 @@ def __init__(self, info_queue: mp.Queue) -> None: def init_custom(self): NIXLChunckedPrefillForPrefillNode.init_custom(self) - super().init_custom() return def _filter_not_ready_reqs(self, req_ids: List[int]) -> List[InferReq]: diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_trans_process.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_trans_process.py index 063ce5c6a..5f078806e 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_trans_process.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_trans_process.py @@ -50,7 +50,8 @@ def _init_env( # 从共享内存读取所有rank的mem_manager node_world_size = args.tp // args.nnodes mem_managers: List[MemoryManager] = [ - MemoryManager.loads_from_shm(rank_in_node=rank) for rank in range(node_world_size) + MemoryManager.loads_from_shm(rank_in_node=rank, current_rank_in_node=device_id + node_world_size) + for rank in range(node_world_size) ] task_out_queue.put("get_mem_managers_ok") From 53852d2e24b53c045ca31dcbf6be29e245c99f70 Mon Sep 17 00:00:00 2001 From: wanzihao <1060304770@qq.com> Date: Fri, 21 Nov 2025 07:45:47 +0000 Subject: [PATCH 27/34] fix position_ids empty --- lightllm/server/core/objs/req.py | 10 
----- .../dp_backend/dp_shared_kv_trans.py | 42 +++++++++---------- .../mode_backend/dp_backend/impl.py | 1 - 3 files changed, 21 insertions(+), 32 deletions(-) diff --git a/lightllm/server/core/objs/req.py b/lightllm/server/core/objs/req.py index 6ee68c52a..0d2e7ae38 100644 --- a/lightllm/server/core/objs/req.py +++ b/lightllm/server/core/objs/req.py @@ -122,12 +122,6 @@ class Req(ctypes.Structure): ("cpu_cache_match_page_indexes", CpuCachePageList), # 分块hash的块大小 ("cpu_cache_token_page_size", ctypes.c_int), - # 所有DP中的最大kv cache的长度 - ("dp_max_kv_len", ctypes.c_int), - # 拥有最大kv cache长度的dp_rank - ("dp_max_kv_rank", ctypes.c_int), - # 原DP的kv len - ("dp_origin_kv_len", ctypes.c_int), ] def get_str(self): @@ -189,10 +183,6 @@ def init( self.post_init() self.cpu_cache_token_page_size = get_env_start_args().cpu_cache_token_page_size - # 初始化DP模式相关字段 - self.dp_max_kv_len = 0 - self.dp_max_kv_rank = -1 - self.dp_origin_kv_len = 0 if get_env_start_args().enable_cpu_cache: self._fill_input_token_hash() return diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/dp_shared_kv_trans.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/dp_shared_kv_trans.py index ccab6934d..780a0bcc3 100644 --- a/lightllm/server/router/model_infer/mode_backend/dp_backend/dp_shared_kv_trans.py +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/dp_shared_kv_trans.py @@ -37,7 +37,7 @@ def fill_reqs_info(self, reqs: List[InferReq]): 填充请求的 kv 信息到共享内存中 """ self.backend.node_nccl_group.barrier() - if self.backend.rank_in_dp == 0: + if self.backend.is_master_in_dp: self.shared_req_infos.arr[0 : len(reqs), self.dp_rank_in_node, self._KV_LEN_INDEX] = [ req.cur_kv_len for req in reqs ] @@ -68,7 +68,8 @@ def build_shared_kv_trans_tasks( ): # 当前请求是本 dp_rank 负责的 is_current_dp_handle = req_dp_rank == self.dp_rank_in_node - trans_size = max_req_radix_cache_len - req.cur_kv_len + # 计算需要传输的 kv 长度, 不能超过 req.get_cur_total_len() - 1 + trans_size = min(max_req_radix_cache_len, req.get_cur_total_len() - 1) - req.cur_kv_len if is_current_dp_handle and trans_size > 0 and g_infer_context.get_can_alloc_token_num() > trans_size: g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(trans_size) @@ -78,7 +79,7 @@ def build_shared_kv_trans_tasks( max_kv_len_mem_manager_index = max_kv_len_dp_rank * self.backend.dp_world_size + self.backend.rank_in_dp max_kv_len_mem_manager: MemoryManager = self.backend.mem_managers[max_kv_len_mem_manager_index] max_kv_len_mem_indexes = max_kv_len_mem_manager.req_to_token_indexs[ - max_kv_len_req_idx, req.cur_kv_len : max_req_radix_cache_len + max_kv_len_req_idx, req.cur_kv_len : req.cur_kv_len + trans_size ] trans_tasks.append( TransTask( @@ -143,22 +144,21 @@ class TransTask: def init_dp_kv_shared(backend): - if backend.enable_dp_prompt_cache_fetch: - torch.cuda.set_device(get_current_device_id()) - - backend.dp_kv_shared_moudle = DPKVSharedMoudle( - max_req_num=backend.args.running_max_req_size, - max_req_seq_len=backend.args.max_req_total_len + 8, - dp_size_in_node=backend.dp_size_in_node, - backend=backend, - ) - backend.node_nccl_group.barrier() - - # Collect mem_managers from all ranks - backend.mem_managers = [] - for rank_idx in range(backend.node_world_size): - if rank_idx != backend.rank_in_node: - backend.mem_managers.append(MemoryManager.loads_from_shm(rank_idx, backend.rank_in_node)) - else: - backend.mem_managers.append(backend.model.mem_manager) + torch.cuda.set_device(get_current_device_id()) + + backend.dp_kv_shared_moudle = DPKVSharedMoudle( + 
max_req_num=backend.args.running_max_req_size, + max_req_seq_len=backend.args.max_req_total_len + 8, + dp_size_in_node=backend.dp_size_in_node, + backend=backend, + ) + backend.node_nccl_group.barrier() + + # Collect mem_managers from all ranks + backend.mem_managers = [] + for rank_idx in range(backend.node_world_size): + if rank_idx != backend.rank_in_node: + backend.mem_managers.append(MemoryManager.loads_from_shm(rank_idx, backend.rank_in_node)) + else: + backend.mem_managers.append(backend.model.mem_manager) return diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py index d43c910e3..dc1893a86 100644 --- a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py @@ -36,7 +36,6 @@ def __init__(self) -> None: # 用于控制每一步是执行prefill 和 decode 还是跳过 self.control_state_machine = DPControlState(backend=self) - self.enable_dp_prompt_cache_fetch = get_env_start_args().enable_dp_prompt_cache_fetch # 在 mtp 模式下切换绑定的prefill 和 decode 函数 if get_env_start_args().mtp_mode: From 4e9a6c5ded3344d6038d4f0b450de4d6aad83f9b Mon Sep 17 00:00:00 2001 From: wanzihao <1060304770@qq.com> Date: Mon, 24 Nov 2025 07:39:09 +0000 Subject: [PATCH 28/34] fix --- .../model_infer/mode_backend/base_backend.py | 26 +++++++++++- .../dp_backend/dp_shared_kv_trans.py | 42 ++++--------------- .../mode_backend/dp_backend/impl.py | 6 +-- 3 files changed, 36 insertions(+), 38 deletions(-) diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 706adbb47..3a4892ebb 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -39,7 +39,6 @@ from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token from lightllm.server.pd_io_struct import NIXLChunckedTransTaskRet from .multi_level_kv_cache import MultiLevelKvCacheModule -from .dp_backend.dp_shared_kv_trans import init_dp_kv_shared class ModeBackend: @@ -215,11 +214,12 @@ def init_model(self, kvargs): or self.args.enable_dp_prompt_cache_fetch ): self.model.mem_manager.write_to_shm(req_manager=self.model.req_manager) + dist.barrier(group=self.node_nccl_group) self.init_custom() if self.args.enable_dp_prompt_cache_fetch: - init_dp_kv_shared(self) + self.init_dp_kv_shared() self.shm_reqs_io_buffer = ShmObjsIOBuffer() # 只会在 nixl pd 模式下才会使用,用于上传分块传输任务是否成功。 @@ -243,6 +243,28 @@ def init_model(self, kvargs): def init_custom(self): pass + def init_dp_kv_shared(self): + from lightllm.server.router.model_infer.mode_backend.dp_backend.dp_shared_kv_trans import DPKVSharedMoudle + from lightllm.common.mem_manager import MemoryManager + + torch.cuda.set_device(get_current_device_id()) + + self.dp_kv_shared_module = DPKVSharedMoudle( + max_req_num=self.args.running_max_req_size, + max_req_seq_len=self.args.max_req_total_len + 8, + dp_size_in_node=self.dp_size_in_node, + backend=self, + ) + + # Collect mem_managers from all ranks + self.mem_managers = [] + for rank_idx in range(self.node_world_size): + if rank_idx != self.rank_in_node: + self.mem_managers.append(MemoryManager.loads_from_shm(rank_idx, self.rank_in_node)) + else: + self.mem_managers.append(self.model.mem_manager) + return + def get_max_total_token_num(self): return self.model.mem_manager.size diff --git 
a/lightllm/server/router/model_infer/mode_backend/dp_backend/dp_shared_kv_trans.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/dp_shared_kv_trans.py index 780a0bcc3..9164bd925 100644 --- a/lightllm/server/router/model_infer/mode_backend/dp_backend/dp_shared_kv_trans.py +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/dp_shared_kv_trans.py @@ -1,4 +1,5 @@ # 该文件用于提供在数据dp并行的推理模式下,共享kv cache trans相关的功能函数模块 +import time import numpy as np import dataclasses import torch @@ -9,6 +10,8 @@ from lightllm.server.core.objs.shm_array import ShmArray from ...infer_batch import InferReq from lightllm.utils.dist_utils import get_current_device_id +from lightllm.server.router.model_infer.infer_batch import g_infer_context +import torch.distributed as dist class DPKVSharedMoudle: @@ -36,7 +39,7 @@ def fill_reqs_info(self, reqs: List[InferReq]): """ 填充请求的 kv 信息到共享内存中 """ - self.backend.node_nccl_group.barrier() + dist.barrier(group=self.backend.node_nccl_group) if self.backend.is_master_in_dp: self.shared_req_infos.arr[0 : len(reqs), self.dp_rank_in_node, self._KV_LEN_INDEX] = [ req.cur_kv_len for req in reqs @@ -54,9 +57,7 @@ def build_shared_kv_trans_tasks( """ 构建共享kv交换信息 """ - from lightllm.server.router.model_infer.infer_batch import g_infer_context - - self.backend.node_nccl_group.barrier() + dist.barrier(group=self.backend.node_nccl_group) trans_tasks: List[TransTask] = [] rank_max_radix_cache_lens = np.max( @@ -96,7 +97,6 @@ def build_shared_kv_trans_tasks( def kv_trans(self, trans_tasks: List["TransTask"]): from lightllm.server.router.model_infer.infer_batch import g_infer_context - self.backend.node_nccl_group.barrier() # kv 传输 if len(trans_tasks) > 0: max_kv_len_mem_indexes = [] @@ -104,14 +104,13 @@ def kv_trans(self, trans_tasks: List["TransTask"]): mem_indexes = [] for i, trans_task in enumerate(trans_tasks): - max_kv_len_mem_indexes.extend(trans_task.max_kv_len_mem_indexes) + max_kv_len_mem_indexes.append(trans_task.max_kv_len_mem_indexes) max_kv_len_dp_ranks.extend([trans_task.max_kv_len_dp_rank] * len(trans_task.max_kv_len_mem_indexes)) - mem_indexes.extend(trans_task.mem_indexes) + mem_indexes.append(trans_task.mem_indexes) - max_kv_len_mem_indexes_tensor = torch.tensor(max_kv_len_mem_indexes, dtype=torch.int64, device="cuda") + max_kv_len_mem_indexes_tensor = torch.cat(max_kv_len_mem_indexes).to(dtype=torch.int64, device="cuda") max_kv_len_dp_ranks_tensor = torch.tensor(max_kv_len_dp_ranks, dtype=torch.int32, device="cuda") - mem_indexes_tensor = torch.tensor(mem_indexes, dtype=torch.int64, device="cuda") - + mem_indexes_tensor = torch.cat(mem_indexes).to(dtype=torch.int64, device="cuda") self.backend.model.mem_manager.copy_kv_from_other_dp_ranks( mem_managers=self.backend.mem_managers, move_token_indexes=max_kv_len_mem_indexes_tensor, @@ -122,7 +121,6 @@ def kv_trans(self, trans_tasks: List["TransTask"]): ) self.backend.logger.info(f"dp_i {self.dp_rank_in_node} transfer kv tokens num: {len(mem_indexes_tensor)}") - self.backend.node_nccl_group.barrier() for trans_task in trans_tasks: g_infer_context.req_manager.req_to_token_indexs[ trans_task.req.req_idx, @@ -131,7 +129,6 @@ def kv_trans(self, trans_tasks: List["TransTask"]): trans_task.req.cur_kv_len += len(trans_task.mem_indexes) if self.backend.is_master_in_dp: trans_task.req.shm_req.shm_cur_kv_len = trans_task.req.cur_kv_len - self.backend.node_nccl_group.barrier() @dataclasses.dataclass @@ -141,24 +138,3 @@ class TransTask: max_kv_len_dp_rank: int max_kv_len_mem_manager_index: int 
max_kv_len_mem_indexes: torch.Tensor - - -def init_dp_kv_shared(backend): - torch.cuda.set_device(get_current_device_id()) - - backend.dp_kv_shared_moudle = DPKVSharedMoudle( - max_req_num=backend.args.running_max_req_size, - max_req_seq_len=backend.args.max_req_total_len + 8, - dp_size_in_node=backend.dp_size_in_node, - backend=backend, - ) - backend.node_nccl_group.barrier() - - # Collect mem_managers from all ranks - backend.mem_managers = [] - for rank_idx in range(backend.node_world_size): - if rank_idx != backend.rank_in_node: - backend.mem_managers.append(MemoryManager.loads_from_shm(rank_idx, backend.rank_in_node)) - else: - backend.mem_managers.append(backend.model.mem_manager) - return diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py index dc1893a86..714c29722 100644 --- a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py @@ -79,9 +79,9 @@ def _init_reqs(self, reqs: List[Tuple]): infer_reqs = g_infer_context.add_reqs(reqs, init_prefix_cache=True) req_dp_ranks = [req[3] for req in reqs] - self.dp_kv_shared_moudle.fill_reqs_info(reqs=infer_reqs) - trans_taskes = self.dp_kv_shared_moudle.build_shared_kv_trans_tasks(reqs=infer_reqs, req_dp_ranks=req_dp_ranks) - self.dp_kv_shared_moudle.kv_trans(trans_tasks=trans_taskes) + self.dp_kv_shared_module.fill_reqs_info(reqs=infer_reqs) + trans_taskes = self.dp_kv_shared_module.build_shared_kv_trans_tasks(reqs=infer_reqs, req_dp_ranks=req_dp_ranks) + self.dp_kv_shared_module.kv_trans(trans_tasks=trans_taskes) g_infer_context._filter(finished_request_ids=[req[0] for req in other_dp_reqs]) g_infer_state_lock.release() From 704830fe2c65456a0893628cf2abe4c1627b82e7 Mon Sep 17 00:00:00 2001 From: wanzihao <1060304770@qq.com> Date: Wed, 26 Nov 2025 07:26:01 +0000 Subject: [PATCH 29/34] add news --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 364e96201..1ba4e6967 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,7 @@ LightLLM is a Python-based LLM (Large Language Model) inference and serving fram [English Docs](https://lightllm-en.readthedocs.io/en/latest/) | [中文文档](https://lightllm-cn.readthedocs.io/en/latest/) | [Blogs](https://modeltc.github.io/lightllm-blog/) ## News +- [2025/11] 🚀 Prefix KV Cache Transfer between DP rankers is now supported! Check out the technical deep dive in our [blog post](https://light-ai.top/lightllm-blog/2025/11/18/dp_kv_fetch.html). - [2025/09] 🔥 LightLLM [v1.1.0](https://www.light-ai.top/lightllm-blog/2025/09/03/lightllm.html) release! - [2025/08] Pre $^3$ achieves the outstanding paper award of [ACL2025](https://2025.aclweb.org/program/awards/). - [2025/05] LightLLM paper on constrained decoding accepted by [ACL2025](https://arxiv.org/pdf/2506.03887) (Pre $^3$: Enabling Deterministic Pushdown Automata for Faster Structured LLM Generation). 
For a more accessible overview of the research with key insights and examples, check out our blog post: [LightLLM Blog](https://www.light-ai.top/lightllm-blog/2025/06/15/pre3.html) From 796f036027e8d18e1ae9b3e19af1a2d392afb5a2 Mon Sep 17 00:00:00 2001 From: wanzihao <1060304770@qq.com> Date: Thu, 4 Dec 2025 07:35:42 +0000 Subject: [PATCH 30/34] update readme --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 1ba4e6967..8ff44b564 100644 --- a/README.md +++ b/README.md @@ -37,7 +37,7 @@ LightLLM is a Python-based LLM (Large Language Model) inference and serving fram ## Performance -Learn more in the release blogs: [v1.0.0 blog](https://www.light-ai.top/lightllm-blog//by%20mtc%20team/2025/02/16/lightllm/). +Learn more in the release blogs: [v1.1.0 blog](https://www.light-ai.top/lightllm-blog/2025/09/03/lightllm.html). ## FAQ From 1fe6cfad89ebab24d83a6ff1c5621b6549175e25 Mon Sep 17 00:00:00 2001 From: wanzihao <1060304770@qq.com> Date: Thu, 4 Dec 2025 08:33:52 +0000 Subject: [PATCH 31/34] update readme --- README.md | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 8ff44b564..fd8a29c0f 100644 --- a/README.md +++ b/README.md @@ -20,11 +20,14 @@ LightLLM is a Python-based LLM (Large Language Model) inference and serving fram [English Docs](https://lightllm-en.readthedocs.io/en/latest/) | [中文文档](https://lightllm-cn.readthedocs.io/en/latest/) | [Blogs](https://modeltc.github.io/lightllm-blog/) -## News +## Tech Blogs - [2025/11] 🚀 Prefix KV Cache Transfer between DP rankers is now supported! Check out the technical deep dive in our [blog post](https://light-ai.top/lightllm-blog/2025/11/18/dp_kv_fetch.html). +- [2025/05] LightLLM paper on constrained decoding accepted by [ACL2025](https://arxiv.org/pdf/2506.03887) (Pre $^3$: Enabling Deterministic Pushdown Automata for Faster Structured LLM Generation). For a more accessible overview of the research with key insights and examples, check out our blog post: [LightLLM Blog](https://www.light-ai.top/lightllm-blog/2025/06/15/pre3.html) + +## News + - [2025/09] 🔥 LightLLM [v1.1.0](https://www.light-ai.top/lightllm-blog/2025/09/03/lightllm.html) release! - [2025/08] Pre $^3$ achieves the outstanding paper award of [ACL2025](https://2025.aclweb.org/program/awards/). -- [2025/05] LightLLM paper on constrained decoding accepted by [ACL2025](https://arxiv.org/pdf/2506.03887) (Pre $^3$: Enabling Deterministic Pushdown Automata for Faster Structured LLM Generation). For a more accessible overview of the research with key insights and examples, check out our blog post: [LightLLM Blog](https://www.light-ai.top/lightllm-blog/2025/06/15/pre3.html) - [2025/04] LightLLM paper on request scheduler published in [ASPLOS’25](https://dl.acm.org/doi/10.1145/3676641.3716011) (Past-Future Scheduler for LLM Serving under SLA Guarantees) - [2025/02] 🔥 LightLLM v1.0.0 release, achieving the **fastest DeepSeek-R1** serving performance on single H200 machine. 
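A note on the `dp_shared_kv_trans` hunks in the patches above: `build_shared_kv_trans_tasks` clamps how many prefix KV tokens the owning DP rank may copy from the rank holding the longest radix-cache match. The copy skips whatever prefix the rank already holds and never includes the request's final token, which must still be produced by the local prefill step. A minimal sketch of that calculation follows; the standalone helper and its flat signature are illustrative, not part of the patch.

```python
# Sketch of the transfer-length clamp used in build_shared_kv_trans_tasks.
# The helper name and flat signature are illustrative only.
def compute_trans_size(max_req_radix_cache_len: int,
                       cur_total_len: int,
                       cur_kv_len: int) -> int:
    # Never copy the request's last token (local prefill still produces it)
    # and skip the prefix this rank already has cached.
    return min(max_req_radix_cache_len, cur_total_len - 1) - cur_kv_len


# Example: a peer DP rank caches 900 matching prefix tokens, the request is
# 512 tokens long and 128 of them are already cached locally -> copy 383.
assert compute_trans_size(900, 512, 128) == 383
```

A transfer task is only queued when the result is positive and the local KV pool can cover it, which is the `g_infer_context.get_can_alloc_token_num() > trans_size` guard in the hunk.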
From 0e13fb93a3abba699975ee094d7702291b17b825 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 9 Dec 2025 04:52:08 +0000 Subject: [PATCH 32/34] fix --- .../kv_cache_mem_manager/mem_manager.py | 41 +++++++++++++------ .../model_infer/mode_backend/base_backend.py | 4 +- .../decode_node_impl/decode_trans_process.py | 3 +- .../prefill_trans_process.py | 3 +- .../decode_node_impl/decode_trans_process.py | 3 +- .../prefill_trans_process.py | 3 +- 6 files changed, 34 insertions(+), 23 deletions(-) diff --git a/lightllm/common/kv_cache_mem_manager/mem_manager.py b/lightllm/common/kv_cache_mem_manager/mem_manager.py index 332fa6dcb..d8fd93009 100755 --- a/lightllm/common/kv_cache_mem_manager/mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/mem_manager.py @@ -19,6 +19,7 @@ from lightllm.utils.device_utils import kv_trans_use_p2p from lightllm.utils.shm_utils import create_or_link_shm from multiprocessing.reduction import ForkingPickler +from filelock import FileLock logger = init_logger(__name__) @@ -450,25 +451,39 @@ def write_to_shm(self, req_manager): # 避免过多无用的数据复制和传输开销。 self.req_to_token_indexs: torch.Tensor = req_manager.req_to_token_indexs - shm_name = f"{get_unique_server_name()}_mem_manager_{get_current_rank_in_node()}" - for rank_in_node in range(0, get_node_world_size() * 2): - obj_bytes = ForkingPickler.dumps(self).tobytes() + lock = FileLock(f"/tmp/{get_unique_server_name()}_mem_manager_lock") + with lock: + node_world_size = get_node_world_size() + shm_name = f"{get_unique_server_name()}_mem_manager_{get_current_rank_in_node()}" + obj_bytes_array = [ForkingPickler.dumps(self).tobytes() for _ in range(node_world_size * 2)] + obj_size = len(obj_bytes_array[0]) shm = create_or_link_shm( - name=f"{shm_name}_{rank_in_node}", expected_size=len(obj_bytes) + 4, force_mode="create" + name=shm_name, expected_size=obj_size * (node_world_size * 2) + 4 + 4, force_mode="create" ) logger.info(f"create shm {shm.name} size {shm.size} for mem manger shared buffer") - shm.buf[0:4] = len(obj_bytes).to_bytes(4, "little") - shm.buf[4 : 4 + len(obj_bytes)] = obj_bytes + shm.buf[0:4] = (node_world_size * 2).to_bytes(4, "little") + shm.buf[4:8] = obj_size.to_bytes(4, "little") + start_index = 8 + for obj_bytes in obj_bytes_array: + shm.buf[start_index : start_index + obj_size] = obj_bytes + start_index += obj_size @staticmethod - def loads_from_shm(rank_in_node: int, current_rank_in_node: int) -> "MemoryManager": - shm_name = f"{get_unique_server_name()}_mem_manager_{rank_in_node}_{current_rank_in_node}" + def loads_from_shm(rank_in_node: int) -> "MemoryManager": + shm_name = f"{get_unique_server_name()}_mem_manager_{rank_in_node}" + lock = FileLock(f"/tmp/{get_unique_server_name()}_mem_manager_lock") logger.info(f"get memmanager from shm {shm_name}") - shm = create_or_link_shm(name=shm_name, expected_size=-1, force_mode="link") - bytes_len = int.from_bytes(shm.buf[0:4], "little") - obj_bytes = shm.buf[4 : 4 + bytes_len].tobytes() - shm.close() - return ForkingPickler.loads(obj_bytes) + with lock: + shm = create_or_link_shm(name=shm_name, expected_size=-1, force_mode="link") + left_num = int.from_bytes(shm.buf[0:4], "little") + obj_size = int.from_bytes(shm.buf[4:8], "little") + assert left_num > 0 + end_index = 8 + left_num * obj_size + start_index = 8 + (left_num - 1) * obj_size + obj_bytes = shm.buf[start_index:end_index].tobytes() + shm.buf[0:4] = (left_num - 1).to_bytes(4, byteorder="little") + shm.close() + return ForkingPickler.loads(obj_bytes) class ReadOnlyStaticsMemoryManager: diff --git 
a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 3a4892ebb..415c23eb3 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -245,7 +245,7 @@ def init_custom(self): def init_dp_kv_shared(self): from lightllm.server.router.model_infer.mode_backend.dp_backend.dp_shared_kv_trans import DPKVSharedMoudle - from lightllm.common.mem_manager import MemoryManager + from lightllm.common.kv_cache_mem_manager import MemoryManager torch.cuda.set_device(get_current_device_id()) @@ -260,7 +260,7 @@ def init_dp_kv_shared(self): self.mem_managers = [] for rank_idx in range(self.node_world_size): if rank_idx != self.rank_in_node: - self.mem_managers.append(MemoryManager.loads_from_shm(rank_idx, self.rank_in_node)) + self.mem_managers.append(MemoryManager.loads_from_shm(rank_idx)) else: self.mem_managers.append(self.model.mem_manager) return diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py index 7023176e8..cdca63887 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py @@ -114,8 +114,7 @@ def _init_env(args, device_id: int, task_in_queue: mp.Queue, task_out_queue: mp. # 从共享内存读取所有rank的mem_manager node_world_size = args.tp // args.nnodes mem_managers: List[MemoryManager] = [ - MemoryManager.loads_from_shm(rank_in_node=rank, current_rank_in_node=device_id + node_world_size) - for rank in range(node_world_size) + MemoryManager.loads_from_shm(rank_in_node=rank) for rank in range(node_world_size) ] task_out_queue.put("get_mem_managers_ok") diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py index f411e96a6..a328e3e08 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py @@ -119,8 +119,7 @@ def _init_env( # 从共享内存读取所有rank的mem_manager node_world_size = args.tp // args.nnodes mem_managers: List[MemoryManager] = [ - MemoryManager.loads_from_shm(rank_in_node=rank, current_rank_in_node=device_id + node_world_size) - for rank in range(node_world_size) + MemoryManager.loads_from_shm(rank_in_node=rank) for rank in range(node_world_size) ] task_out_queue.put("get_mem_managers_ok") connect_id_to_comm: Dict[str, PyNcclCommunicator] = {} diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_trans_process.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_trans_process.py index 203d65c65..b04cbb900 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_trans_process.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_trans_process.py @@ -58,8 +58,7 @@ def _init_env( # 从共享内存读取所有rank的mem_manager node_world_size = args.tp // args.nnodes mem_managers: 
List[MemoryManager] = [ - MemoryManager.loads_from_shm(rank_in_node=rank, current_rank_in_node=device_id + node_world_size) - for rank in range(node_world_size) + MemoryManager.loads_from_shm(rank_in_node=rank) for rank in range(node_world_size) ] task_out_queue.put("get_mem_managers_ok") diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_trans_process.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_trans_process.py index 5f078806e..063ce5c6a 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_trans_process.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_trans_process.py @@ -50,8 +50,7 @@ def _init_env( # 从共享内存读取所有rank的mem_manager node_world_size = args.tp // args.nnodes mem_managers: List[MemoryManager] = [ - MemoryManager.loads_from_shm(rank_in_node=rank, current_rank_in_node=device_id + node_world_size) - for rank in range(node_world_size) + MemoryManager.loads_from_shm(rank_in_node=rank) for rank in range(node_world_size) ] task_out_queue.put("get_mem_managers_ok") From 956796c842cb65f21f7ca775c769b9c8512ee8c1 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 9 Dec 2025 05:38:26 +0000 Subject: [PATCH 33/34] fix --- README.md | 3 +-- .../server/router/model_infer/mode_backend/base_backend.py | 5 ++--- .../mode_backend/dp_backend/dp_shared_kv_trans.py | 2 +- .../router/model_infer/mode_backend/dp_backend/impl.py | 5 ----- 4 files changed, 4 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index fd8a29c0f..8137a1b47 100644 --- a/README.md +++ b/README.md @@ -22,12 +22,11 @@ LightLLM is a Python-based LLM (Large Language Model) inference and serving fram ## Tech Blogs - [2025/11] 🚀 Prefix KV Cache Transfer between DP rankers is now supported! Check out the technical deep dive in our [blog post](https://light-ai.top/lightllm-blog/2025/11/18/dp_kv_fetch.html). -- [2025/05] LightLLM paper on constrained decoding accepted by [ACL2025](https://arxiv.org/pdf/2506.03887) (Pre $^3$: Enabling Deterministic Pushdown Automata for Faster Structured LLM Generation). For a more accessible overview of the research with key insights and examples, check out our blog post: [LightLLM Blog](https://www.light-ai.top/lightllm-blog/2025/06/15/pre3.html) ## News - - [2025/09] 🔥 LightLLM [v1.1.0](https://www.light-ai.top/lightllm-blog/2025/09/03/lightllm.html) release! - [2025/08] Pre $^3$ achieves the outstanding paper award of [ACL2025](https://2025.aclweb.org/program/awards/). +- [2025/05] LightLLM paper on constrained decoding accepted by [ACL2025](https://arxiv.org/pdf/2506.03887) (Pre $^3$: Enabling Deterministic Pushdown Automata for Faster Structured LLM Generation). For a more accessible overview of the research with key insights and examples, check out our blog post: [LightLLM Blog](https://www.light-ai.top/lightllm-blog/2025/06/15/pre3.html) - [2025/04] LightLLM paper on request scheduler published in [ASPLOS’25](https://dl.acm.org/doi/10.1145/3676641.3716011) (Past-Future Scheduler for LLM Serving under SLA Guarantees) - [2025/02] 🔥 LightLLM v1.0.0 release, achieving the **fastest DeepSeek-R1** serving performance on single H200 machine. 
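The `write_to_shm` / `loads_from_shm` rework in PATCH 32 above changes the hand-off protocol: the writer now serializes `node_world_size * 2` identical copies of the pickled `MemoryManager` into a single shared-memory segment behind an 8-byte header (remaining-copy count, then per-copy size), and every reader, holding the same `FileLock`, pops the last unconsumed copy and decrements the count, so `current_rank_in_node` no longer needs to be threaded through the callers. Below is a minimal sketch of the pop-one-copy pattern using plain `pickle` and illustrative names; the real code uses `ForkingPickler` and the server's own shm helpers.

```python
import pickle
from multiprocessing import shared_memory
from filelock import FileLock


def write_copies(name: str, obj, copies: int) -> None:
    # One blob per prospective reader; all copies are identical.
    blobs = [pickle.dumps(obj) for _ in range(copies)]
    size = len(blobs[0])
    shm = shared_memory.SharedMemory(name=name, create=True, size=8 + size * copies)
    shm.buf[0:4] = copies.to_bytes(4, "little")  # copies not yet consumed
    shm.buf[4:8] = size.to_bytes(4, "little")    # bytes per copy
    for i, blob in enumerate(blobs):
        shm.buf[8 + i * size : 8 + (i + 1) * size] = blob
    shm.close()


def pop_copy(name: str):
    # The file lock serializes readers across processes, mirroring the
    # FileLock that guards loads_from_shm in the patch.
    with FileLock(f"/tmp/{name}.lock"):
        shm = shared_memory.SharedMemory(name=name)
        left = int.from_bytes(shm.buf[0:4], "little")
        size = int.from_bytes(shm.buf[4:8], "little")
        assert left > 0, "no serialized copy left in shm"
        start = 8 + (left - 1) * size            # take the last remaining copy
        obj = pickle.loads(bytes(shm.buf[start:start + size]))
        shm.buf[0:4] = (left - 1).to_bytes(4, "little")
        shm.close()
        return obj
```

The `node_world_size * 2` copy count is not explained in the patch; it looks like a per-segment budget for every in-node consumer (the other backend ranks plus the PD transfer processes), each of which takes exactly one copy.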
diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 415c23eb3..a780c4da0 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -213,6 +213,8 @@ def init_model(self, kvargs): self.args.run_mode in ["nixl_prefill", "nixl_decode", "prefill", "decode"] or self.args.enable_dp_prompt_cache_fetch ): + # 如果存在需要跨进程使用mem manger的特性,则将mem manager写入到 shm中,方便 + # 读取 self.model.mem_manager.write_to_shm(req_manager=self.model.req_manager) dist.barrier(group=self.node_nccl_group) @@ -229,9 +231,6 @@ def init_model(self, kvargs): if self.args.mtp_mode: self.init_mtp_draft_model(kvargs) - # 如果存在需要跨进程使用mem manger的特性,则将mem manager写入到 shm中,方便 - # 读取 - # 启动infer_loop_thread, 启动两个线程进行推理,对于具备双batch推理折叠得场景 # 可以降低 cpu overhead,大幅提升gpu得使用率。 self.infer_loop_thread = threading.Thread(target=self.infer_loop, daemon=True) diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/dp_shared_kv_trans.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/dp_shared_kv_trans.py index 9164bd925..5de90bef6 100644 --- a/lightllm/server/router/model_infer/mode_backend/dp_backend/dp_shared_kv_trans.py +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/dp_shared_kv_trans.py @@ -4,7 +4,7 @@ import dataclasses import torch from typing import List -from lightllm.common.mem_manager import MemoryManager +from lightllm.common.kv_cache_mem_manager import MemoryManager from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args from lightllm.utils.dist_utils import get_dp_rank_in_node from lightllm.server.core.objs.shm_array import ShmArray diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py index 714c29722..734dc9998 100644 --- a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py @@ -1,11 +1,8 @@ import torch import time -import numpy as np -import os import torch.nn.functional as F import torch.distributed as dist from typing import List, Tuple, Optional, Callable -from lightllm.common.kv_trans_kernel.kv_trans_v2 import kv_trans_for_dp from lightllm.server.router.model_infer.mode_backend.base_backend import ModeBackend from lightllm.common.basemodel.batch_objs import ModelOutput, ModelInput from lightllm.server.router.model_infer.infer_batch import InferSamplingParams, g_infer_context, InferReq @@ -26,8 +23,6 @@ from lightllm.server.router.model_infer.pin_mem_manager import g_pin_mem_manager from lightllm.common.basemodel.triton_kernel.mtp_utils import mtp_scatter_next_token_ids from .control_state import DPControlState -from lightllm.common.mem_manager import MemoryManager -from .dp_shared_kv_trans import DPKVSharedMoudle class DPChunkedPrefillBackend(ModeBackend): From 4cc963ea788f87df74a86fe684477ee6bf72f686 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 9 Dec 2025 06:05:13 +0000 Subject: [PATCH 34/34] fix --- test.py | 51 --------------------------------------------------- 1 file changed, 51 deletions(-) delete mode 100644 test.py diff --git a/test.py b/test.py deleted file mode 100644 index 689ae9b77..000000000 --- a/test.py +++ /dev/null @@ -1,51 +0,0 @@ -import torch -import multiprocessing -from multiprocessing.reduction import ForkingPickler - - -def worker(serialized): - 
torch.cuda.set_device(0) - tensor = ForkingPickler.loads(serialized) - print("In worker process:", tensor, type(tensor)) - # import time - # time.sleep(100) - - -def worker1(serialized): - torch.cuda.set_device(1) - tensor = ForkingPickler.loads(serialized) - print("In worker process:", tensor, type(tensor)) - # import time - # time.sleep(100) - - -if __name__ == "__main__": - torch.multiprocessing.set_start_method("spawn") - - torch.cuda.set_device(0) - # Create a tensor on the CUDA device - a = torch.zeros((100,), device="cuda").cuda() - - same = ForkingPickler.dumps(a) - serialized = same.tobytes() - - # Create a new process - process = multiprocessing.Process(target=worker, args=(serialized,)) - - # Start the process - process.start() - - process1 = multiprocessing.Process(target=worker1, args=(serialized,)) - - # Start the process - process1.start() - - # Wait for the process to finish - process.join() - process1.join() - print(a) - import time - - time.sleep(10) - - print("Main process finished.")
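PATCH 34 removes the stand-alone `test.py`. The round trip it exercised, pickling a CUDA tensor with `ForkingPickler` in the parent and rebuilding it from the CUDA IPC handle inside spawned workers bound to different devices, is the same mechanism that the `p2p_fix_rebuild_cuda_tensor` reducer near the start of this series and `MemoryManager`'s shm serialization rely on. A condensed sketch of that round trip; it follows the deleted script and, like it, assumes two visible GPUs.

```python
import torch
import multiprocessing
from multiprocessing.reduction import ForkingPickler


def worker(device_id: int, payload: bytes) -> None:
    torch.cuda.set_device(device_id)
    # Rebuilding opens the CUDA IPC handle exported by the parent; the
    # tensor data is never copied through host memory.
    tensor = ForkingPickler.loads(payload)
    print(f"worker bound to cuda:{device_id} sees", tensor.shape, tensor.device)


if __name__ == "__main__":
    torch.multiprocessing.set_start_method("spawn")
    torch.cuda.set_device(0)
    shared = torch.zeros((100,), device="cuda")
    payload = ForkingPickler.dumps(shared).tobytes()

    procs = [multiprocessing.Process(target=worker, args=(dev, payload))
             for dev in (0, 1)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
    print("parent still owns", shared.shape)
```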