Merged
34 commits
153a409  fix (Dec 9, 2025)
962ef2b  fix (Dec 9, 2025)
0cd2b86  add enable_dp_prompt_cache_fetch (WANDY666, Nov 3, 2025)
0667d71  free_radix_cache_to_get_enough_token instead of skip (WANDY666, Nov 4, 2025)
5a1e22d  add use_openai_api, port, concurrency, history-turns, max-total-token… (WANDY666, Nov 4, 2025)
7551fb7  support pd split (WANDY666, Nov 4, 2025)
a3c44be  add mem_queues (WANDY666, Nov 4, 2025)
1323ce1  little update (WANDY666, Nov 5, 2025)
b72b7ac  fix (Dec 9, 2025)
518b8f1  fix (Dec 9, 2025)
da2eb3d  use node_nccl_group (WANDY666, Nov 6, 2025)
737218e  delete shm_kv_indexes add shared_kv_indexes to reduce shared memory u… (WANDY666, Nov 7, 2025)
b0743a7  layer into triton op (WANDY666, Nov 10, 2025)
26879ea  fix multiple visits to fd (WANDY666, Nov 10, 2025)
725eec3  fix pd mem_manager get failed (WANDY666, Nov 11, 2025)
dad0b83  fix release other shm_reqs (WANDY666, Nov 11, 2025)
6cc9982  add use_for_pd_trans to avoid duplicate name overwriting (WANDY666, Nov 11, 2025)
a97df66  minor change (WANDY666, Nov 14, 2025)
47212f0  add test.py (hiworldwzj, Nov 15, 2025)
cbb1b84  improve mem_manager (hiworldwzj, Nov 15, 2025)
b548946  write mem manager to shm (hiworldwzj, Nov 15, 2025)
6270e0b  fix (hiworldwzj, Nov 15, 2025)
e1769a3  fix (hiworldwzj, Nov 17, 2025)
c4f780f  fix (Nov 18, 2025)
1e1e18a  fix (Nov 18, 2025)
c0f567e  fix (WANDY666, Nov 21, 2025)
53852d2  fix position_ids empty (WANDY666, Nov 21, 2025)
4e9a6c5  fix (WANDY666, Nov 24, 2025)
704830f  add news (WANDY666, Nov 26, 2025)
796f036  update readme (WANDY666, Dec 4, 2025)
1fe6cfa  update readme (WANDY666, Dec 4, 2025)
0e13fb9  fix (Dec 9, 2025)
956796c  fix (Dec 9, 2025)
4cc963e  fix (Dec 9, 2025)
5 changes: 4 additions & 1 deletion README.md
@@ -20,6 +20,9 @@ LightLLM is a Python-based LLM (Large Language Model) inference and serving fram

[English Docs](https://lightllm-en.readthedocs.io/en/latest/) | [中文文档](https://lightllm-cn.readthedocs.io/en/latest/) | [Blogs](https://modeltc.github.io/lightllm-blog/)

## Tech Blogs
- [2025/11] 🚀 Prefix KV Cache Transfer between DP ranks is now supported! Check out the technical deep dive in our [blog post](https://light-ai.top/lightllm-blog/2025/11/18/dp_kv_fetch.html).

## News
- [2025/09] 🔥 LightLLM [v1.1.0](https://www.light-ai.top/lightllm-blog/2025/09/03/lightllm.html) release!
- [2025/08] Pre $^3$ achieves the outstanding paper award of [ACL2025](https://2025.aclweb.org/program/awards/).
@@ -36,7 +39,7 @@ LightLLM is a Python-based LLM (Large Language Model) inference and serving fram

## Performance

Learn more in the release blogs: [v1.0.0 blog](https://www.light-ai.top/lightllm-blog//by%20mtc%20team/2025/02/16/lightllm/).
Learn more in the release blogs: [v1.1.0 blog](https://www.light-ai.top/lightllm-blog/2025/09/03/lightllm.html).

## FAQ

4 changes: 2 additions & 2 deletions docs/CN/source/tutorial/deepseek_deployment.rst
@@ -187,7 +187,7 @@ PD (Prefill-Decode) disaggregation mode deploys the prefill and decode stages separately, which can
export host=$1
export pd_master_ip=$2
nvidia-cuda-mps-control -d
MOE_MODE=EP KV_TRANS_USE_P2P=1 LOADWORKER=18 python -m lightllm.server.api_server \
MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \
--model_dir /path/DeepSeek-R1 \
--run_mode "prefill" \
--tp 8 \
@@ -211,7 +211,7 @@ PD (Prefill-Decode) disaggregation mode deploys the prefill and decode stages separately, which can
export host=$1
export pd_master_ip=$2
nvidia-cuda-mps-control -d
MOE_MODE=EP KV_TRANS_USE_P2P=1 LOADWORKER=18 python -m lightllm.server.api_server \
MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \
--model_dir /path/DeepSeek-R1 \
--run_mode "decode" \
--tp 8 \
4 changes: 2 additions & 2 deletions docs/EN/source/tutorial/deepseek_deployment.rst
@@ -187,7 +187,7 @@ PD (Prefill-Decode) disaggregation mode separates prefill and decode stages for
export host=$1
export pd_master_ip=$2
nvidia-cuda-mps-control -d
MOE_MODE=EP KV_TRANS_USE_P2P=1 LOADWORKER=18 python -m lightllm.server.api_server \
MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \
--model_dir /path/DeepSeek-R1 \
--run_mode "prefill" \
--tp 8 \
@@ -208,7 +208,7 @@ PD (Prefill-Decode) disaggregation mode separates prefill and decode stages for
export host=$1
export pd_master_ip=$2
nvidia-cuda-mps-control -d
MOE_MODE=EP KV_TRANS_USE_P2P=1 LOADWORKER=18 python -m lightllm.server.api_server \
MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \
--model_dir /path/DeepSeek-R1 \
--run_mode "decode" \
--tp 8 \
86 changes: 85 additions & 1 deletion lightllm/common/kv_cache_mem_manager/mem_manager.py
@@ -2,18 +2,24 @@
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from typing import List, Union
from lightllm.common.kv_trans_kernel.kv_trans_v2 import kv_trans_for_dp
from lightllm.server.pd_io_struct import KVMoveTask
from lightllm.utils.log_utils import init_logger
from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt
from lightllm.utils.profile_max_tokens import get_available_gpu_memory, get_total_gpu_memory
from lightllm.common.kv_trans_kernel.kv_trans import kv_trans
from lightllm.utils.dist_utils import get_current_rank_in_node
from lightllm.utils.dist_utils import get_current_rank_in_node, get_node_world_size
from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args
from lightllm.distributed.pynccl import PyNcclCommunicator
from lightllm.utils.dist_utils import get_current_device_id
from lightllm.utils.config_utils import get_num_key_value_heads
from lightllm.common.kv_trans_kernel.nixl_kv_trans import page_io
from lightllm.utils.device_utils import kv_trans_use_p2p
from lightllm.utils.shm_utils import create_or_link_shm
from multiprocessing.reduction import ForkingPickler
from filelock import FileLock

logger = init_logger(__name__)

@@ -401,6 +407,84 @@ def get_index_kv_buffer(self, index):
def load_index_kv_buffer(self, index, load_tensor_dict):
self.kv_buffer[:, index].copy_(load_tensor_dict["kv_buffer"])

def copy_kv_from_other_dp_ranks(
self,
mem_managers: List["MemoryManager"],
move_token_indexes: torch.Tensor,
token_dp_indexes: torch.Tensor,
mem_indexes: torch.Tensor,
dp_size_in_node: int,
rank_in_dp: int,
):
if not hasattr(self, "mem_ptrs_tensor"):
# Build a 1D tensor of kv_buffer base pointers, one entry per mem manager in the node
mems_ptr_list = []
for i in range(0, len(mem_managers)):
mems_ptr_list.append(mem_managers[i].kv_buffer.data_ptr())
self.mem_ptrs_tensor = torch.tensor(mems_ptr_list, dtype=torch.uint64, device="cpu", pin_memory=True)

# Transfer all layers in a single kernel launch
kv_trans_for_dp(
input_mems=self.mem_ptrs_tensor.cuda(non_blocking=True),
input_idx=move_token_indexes,
input_dp_idx=token_dp_indexes,
output=self.kv_buffer,
output_idx=mem_indexes,
dp_size_in_node=dp_size_in_node,
rank_in_dp=rank_in_dp,
)

def write_to_shm(self, req_manager):
"""
Write the mem manager into shm so that features such as PD disaggregation can read it directly, without relying on inter-process queues.
"""
if kv_trans_use_p2p():
from lightllm.server.router.model_infer.mode_backend.continues_batch.pd_mode.p2p_fix import reduce_tensor

mp.reductions.reduce_tensor.__code__ = reduce_tensor.__code__

from lightllm.common.req_manager import ReqManager

req_manager: ReqManager = req_manager

# This is not an elegant design, but for now it lets the dp shared kv swap module directly access req_to_token_indexs in req_manager,
# avoiding unnecessary data copies and transfer overhead.
self.req_to_token_indexs: torch.Tensor = req_manager.req_to_token_indexs

lock = FileLock(f"/tmp/{get_unique_server_name()}_mem_manager_lock")
with lock:
node_world_size = get_node_world_size()
shm_name = f"{get_unique_server_name()}_mem_manager_{get_current_rank_in_node()}"
obj_bytes_array = [ForkingPickler.dumps(self).tobytes() for _ in range(node_world_size * 2)]
obj_size = len(obj_bytes_array[0])
shm = create_or_link_shm(
name=shm_name, expected_size=obj_size * (node_world_size * 2) + 4 + 4, force_mode="create"
)
logger.info(f"create shm {shm.name} size {shm.size} for mem manger shared buffer")
shm.buf[0:4] = (node_world_size * 2).to_bytes(4, "little")
shm.buf[4:8] = obj_size.to_bytes(4, "little")
start_index = 8
for obj_bytes in obj_bytes_array:
shm.buf[start_index : start_index + obj_size] = obj_bytes
start_index += obj_size

@staticmethod
def loads_from_shm(rank_in_node: int) -> "MemoryManager":
shm_name = f"{get_unique_server_name()}_mem_manager_{rank_in_node}"
lock = FileLock(f"/tmp/{get_unique_server_name()}_mem_manager_lock")
logger.info(f"get memmanager from shm {shm_name}")
with lock:
shm = create_or_link_shm(name=shm_name, expected_size=-1, force_mode="link")
left_num = int.from_bytes(shm.buf[0:4], "little")
obj_size = int.from_bytes(shm.buf[4:8], "little")
assert left_num > 0
end_index = 8 + left_num * obj_size
start_index = 8 + (left_num - 1) * obj_size
obj_bytes = shm.buf[start_index:end_index].tobytes()
shm.buf[0:4] = (left_num - 1).to_bytes(4, byteorder="little")
shm.close()
return ForkingPickler.loads(obj_bytes)


class ReadOnlyStaticsMemoryManager:
"""
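A minimal usage sketch (not part of this diff) of the shm hand-off that write_to_shm / loads_from_shm implement above: a consumer process on the same node, such as a kv move manager, can attach to each rank's published MemoryManager without any inter-process queue. The helper name attach_mem_managers is hypothetical; the class, module path, and shm naming come from the file above.

    # Hypothetical consumer-side sketch: each call to loads_from_shm() pops one of the
    # node_world_size * 2 pickled copies stored in that rank's shm segment, so every
    # consumer process gets its own handle.
    from typing import List
    from lightllm.common.kv_cache_mem_manager.mem_manager import MemoryManager

    def attach_mem_managers(node_world_size: int) -> List[MemoryManager]:
        # Segments are named f"{server_name}_mem_manager_{rank_in_node}", one per rank in the node.
        return [MemoryManager.loads_from_shm(rank_in_node=r) for r in range(node_world_size)]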
112 changes: 112 additions & 0 deletions lightllm/common/kv_trans_kernel/kv_trans_v2.py
@@ -191,3 +191,115 @@ def kv_trans_v2_for_d_node(
num_warps=1,
)
return


@triton.jit
def _kv_trans_for_dp_kernel(
input_mems_ptr,
input_stride_0,
input_stride_1,
input_stride_2,
Review comment on lines +199 to +201 (Contributor, severity: medium):

These stride parameters (input_stride_1, input_stride_2) are either unused or their values are not used after being cast. This makes the kernel signature more complex than necessary. The same applies to output_stride_1 and output_stride_2 on lines 206-207. For improved clarity and maintainability, it's recommended to remove these unused parameters. Consequently, the call to this kernel in kv_trans_for_dp should be updated to pass only the required strides (e.g., output.stride(0)) instead of unpacking all strides with *output.stride().

input_stride_3,
input_token_idx_ptr,
input_token_dp_index_ptr,
output_ptr,
output_stride_0,
output_stride_1,
output_stride_2,
output_stride_3,
output_token_idx_ptr,
layer_num: tl.constexpr,
token_num: int,
head_num: int,
head_dim: int,
grid_count: int,
BLOCK_SIZE: tl.constexpr,
NUM_STAGES: tl.constexpr,
CARD_NUM_PER_D: tl.constexpr,
RANK_IN_DP: tl.constexpr,
):
input_stride_0 = tl.cast(input_stride_0, dtype=tl.int64)
input_stride_1 = tl.cast(input_stride_1, dtype=tl.int64)
output_stride_0 = tl.cast(output_stride_0, dtype=tl.int64)
output_stride_1 = tl.cast(output_stride_1, dtype=tl.int64)

head_num_dim = head_num * head_dim
tid = tl.program_id(0)

offs = tl.arange(0, BLOCK_SIZE)
while tid < token_num:
dp_index = tl.load(input_token_dp_index_ptr + tid)
mem_index = RANK_IN_DP + dp_index * CARD_NUM_PER_D
input_token_idx = tl.load(input_token_idx_ptr + tid)
output_token_idx = tl.load(output_token_idx_ptr + tid)

input_ptr = tl.load(input_mems_ptr + mem_index).to(tl.pointer_type(output_ptr.dtype.element_ty))
for layer_idx in tl.range(0, layer_num, 1):
for block_idx in tl.range(0, tl.cdiv(head_num_dim, BLOCK_SIZE), 1, num_stages=NUM_STAGES):
cur_offs = block_idx * BLOCK_SIZE + offs
in_datas = tl.load(
input_ptr + input_stride_0 * layer_idx + input_stride_1 * input_token_idx + cur_offs,
mask=cur_offs < head_num_dim,
)
tl.store(
output_ptr + output_stride_0 * layer_idx + output_stride_1 * output_token_idx + cur_offs,
in_datas,
mask=cur_offs < head_num_dim,
)

tid += grid_count

return


def kv_trans_for_dp(
input_mems: torch.Tensor,
input_idx: torch.Tensor,
input_dp_idx: torch.Tensor,
output: torch.Tensor,
output_idx: torch.Tensor,
dp_size_in_node: int,
rank_in_dp: int,
):
"""
input_mems is a torch.uint64 tensor of shape (mem_num,); each element holds the base pointer of the kv cache (kv_buffer) of the corresponding mem_manager currently in use.
"""
assert input_mems.is_contiguous()
assert output.is_contiguous()
assert len(input_mems.shape) == 1
assert len(output.shape) == 4
assert len(input_idx) == len(output_idx)
assert len(output_idx) == len(input_dp_idx)
assert len(input_mems) % dp_size_in_node == 0

card_num_per_d = len(input_mems) // dp_size_in_node

layer_num, _, head_num, head_dim = output.shape
token_num = len(output_idx)
# 用较少的资源来做数据传输,防止占用过多的 sm 计算单元
grid_count = 20
BLOCK_SIZE = 256
NUM_STAGES = 3
grid = (grid_count,)

_kv_trans_for_dp_kernel[grid](
input_mems,
*output.stride(),
input_idx,
input_dp_idx,
output,
*output.stride(),
output_idx,
layer_num=layer_num,
token_num=token_num,
head_num=head_num,
head_dim=head_dim,
grid_count=grid_count,
BLOCK_SIZE=BLOCK_SIZE,
NUM_STAGES=NUM_STAGES,
CARD_NUM_PER_D=card_num_per_d,
RANK_IN_DP=rank_in_dp,
num_warps=1,
)

return
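To make the pointer-selection logic in the kernel concrete, here is a small illustrative example (assumption: input_mems is ordered rank-by-rank and grouped by DP rank, which is how copy_kv_from_other_dp_ranks builds it; the helper name source_mem_slot is made up for this example).

    # Illustrative only: the kernel's mapping from a token's source DP group and the
    # local rank within a DP group to an index into input_mems,
    # i.e. mem_index = RANK_IN_DP + dp_index * CARD_NUM_PER_D.
    def source_mem_slot(dp_index: int, rank_in_dp: int, card_num_per_d: int) -> int:
        return rank_in_dp + dp_index * card_num_per_d

    # Example: 8 GPUs in the node split into dp_size_in_node = 4 DP groups gives
    # card_num_per_d = 2; a token cached by DP group 3 is read from the peer GPU that
    # holds the same rank_in_dp as the local GPU.
    assert source_mem_slot(dp_index=3, rank_in_dp=1, card_num_per_d=2) == 7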
6 changes: 6 additions & 0 deletions lightllm/server/api_cli.py
@@ -566,4 +566,10 @@ def make_argument_parser() -> argparse.ArgumentParser:
default=None,
help="""Directory used to persist disk cache data. Defaults to a temp directory when not set.""",
)
parser.add_argument(
"--enable_dp_prompt_cache_fetch",
action="store_true",
default=False,
help="""Enable prefix prompt cache fetch for data parallel inference, disabled by default.""",
)
return parser
8 changes: 8 additions & 0 deletions lightllm/server/api_start.py
@@ -269,6 +269,14 @@ def normal_or_p_d_start(args):
args.router_max_wait_tokens = 0

send_and_receive_node_ip(args) # used in multi-node setups to exchange node IPs
# dp prompt cache fetch requires dp > 1
if args.enable_dp_prompt_cache_fetch and args.dp <= 1:
args.enable_dp_prompt_cache_fetch = False
logger.warning(
"""dp <= 1 does not support dp_prompt_cache_fetch;
overriding enable_dp_prompt_cache_fetch to False"""
)

set_env_start_args(args)
logger.info(f"all start args:{args}")

1 change: 1 addition & 0 deletions lightllm/server/core/objs/start_args_type.py
@@ -114,6 +114,7 @@ class StartArgs:
enable_disk_cache: bool = field(default=False)
disk_cache_storage_size: float = field(default=10)
disk_cache_dir: Optional[str] = field(default=None)
enable_dp_prompt_cache_fetch: bool = field(default=False)
# zmq ports
router_port: int = field(default=None)
detokenization_port: int = field(default=None)
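As a hedged sketch of how downstream code might consume the new option (the helper below is hypothetical; only get_env_start_args, the --enable_dp_prompt_cache_fetch flag, the dp <= 1 override in api_start.py, and the StartArgs field above are from this PR):

    # Hypothetical helper, not part of this PR: gate a DP prefix-cache-fetch code path
    # on the new start argument.
    from lightllm.utils.envs_utils import get_env_start_args

    def dp_prompt_cache_fetch_enabled() -> bool:
        args = get_env_start_args()
        # api_start.py forces the flag to False when dp <= 1, so True here implies dp > 1.
        return bool(getattr(args, "enable_dp_prompt_cache_fetch", False))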
12 changes: 4 additions & 8 deletions lightllm/server/router/manager.py
@@ -116,9 +116,6 @@ async def wait_to_model_ready(self):
self.model_rpc_servers = []
# Used for exchanging task info between the kv move manager process and the inference processes.
self.info_queue: mp.Queue = mp.Queue()
self.mem_queues: List[torch.multiprocessing.Queue] = [
torch.multiprocessing.Queue() for _ in range(self.node_world_size)
]
self.rpc_event = multiprocessing.Event()
self.rpc_finished_event = multiprocessing.Event()

@@ -137,7 +134,6 @@ async def wait_to_model_ready(self):
rpc_event=self.rpc_event,
rpc_finished_event=self.rpc_finished_event,
info_queue=self.info_queue,
mem_queue=self.mem_queues[(rank_id % node_world_size)],
router_lock=self.router_lock,
)
)
@@ -205,29 +201,29 @@ async def wait_to_model_ready(self):
start_prefill_kv_move_manager_process,
)

start_prefill_kv_move_manager_process(self.args, self.info_queue, self.mem_queues)
start_prefill_kv_move_manager_process(self.args, self.info_queue)

if self.args.run_mode == "nixl_prefill":
from lightllm.server.router.model_infer.mode_backend.pd_nixl.prefill_node_impl import (
start_prefill_kv_move_manager_process,
)

start_prefill_kv_move_manager_process(self.args, self.info_queue, self.mem_queues)
start_prefill_kv_move_manager_process(self.args, self.info_queue)

if self.args.run_mode == "decode":
# Start the decode kv move manager process
from lightllm.server.router.model_infer.mode_backend.continues_batch.pd_mode.decode_node_impl import (
start_decode_kv_move_manager_process,
)

start_decode_kv_move_manager_process(self.args, self.info_queue, self.mem_queues)
start_decode_kv_move_manager_process(self.args, self.info_queue)

if self.args.run_mode == "nixl_decode":
from lightllm.server.router.model_infer.mode_backend.pd_nixl.decode_node_impl import (
start_decode_kv_move_manager_process,
)

start_decode_kv_move_manager_process(self.args, self.info_queue, self.mem_queues)
start_decode_kv_move_manager_process(self.args, self.info_queue)

return
