
Commit aff4049

[feature] Add prefix_kv_cache transfer between dp rankers. (#1093)
1 parent f5e07af · commit aff4049


46 files changed: +816, -406 lines

README.md

Lines changed: 4 additions & 1 deletion

@@ -20,6 +20,9 @@ LightLLM is a Python-based LLM (Large Language Model) inference and serving fram

 [English Docs](https://lightllm-en.readthedocs.io/en/latest/) | [中文文档](https://lightllm-cn.readthedocs.io/en/latest/) | [Blogs](https://modeltc.github.io/lightllm-blog/)

+## Tech Blogs
+- [2025/11] 🚀 Prefix KV Cache Transfer between DP rankers is now supported! Check out the technical deep dive in our [blog post](https://light-ai.top/lightllm-blog/2025/11/18/dp_kv_fetch.html).
+
 ## News
 - [2025/09] 🔥 LightLLM [v1.1.0](https://www.light-ai.top/lightllm-blog/2025/09/03/lightllm.html) release!
 - [2025/08] Pre $^3$ achieves the outstanding paper award of [ACL2025](https://2025.aclweb.org/program/awards/).
@@ -36,7 +39,7 @@ LightLLM is a Python-based LLM (Large Language Model) inference and serving fram

 ## Performance

-Learn more in the release blogs: [v1.0.0 blog](https://www.light-ai.top/lightllm-blog//by%20mtc%20team/2025/02/16/lightllm/).
+Learn more in the release blogs: [v1.1.0 blog](https://www.light-ai.top/lightllm-blog/2025/09/03/lightllm.html).

 ## FAQ


docs/CN/source/tutorial/deepseek_deployment.rst

Lines changed: 2 additions & 2 deletions

@@ -187,7 +187,7 @@ PD (Prefill-Decode) 分离模式将预填充和解码阶段分离部署,可以
 export host=$1
 export pd_master_ip=$2
 nvidia-cuda-mps-control -d
-MOE_MODE=EP KV_TRANS_USE_P2P=1 LOADWORKER=18 python -m lightllm.server.api_server \
+MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \
 --model_dir /path/DeepSeek-R1 \
 --run_mode "prefill" \
 --tp 8 \
@@ -211,7 +211,7 @@ PD (Prefill-Decode) 分离模式将预填充和解码阶段分离部署,可以
 export host=$1
 export pd_master_ip=$2
 nvidia-cuda-mps-control -d
-MOE_MODE=EP KV_TRANS_USE_P2P=1 LOADWORKER=18 python -m lightllm.server.api_server \
+MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \
 --model_dir /path/DeepSeek-R1 \
 --run_mode "decode" \
 --tp 8 \

docs/EN/source/tutorial/deepseek_deployment.rst

Lines changed: 2 additions & 2 deletions

@@ -187,7 +187,7 @@ PD (Prefill-Decode) disaggregation mode separates prefill and decode stages for
 export host=$1
 export pd_master_ip=$2
 nvidia-cuda-mps-control -d
-MOE_MODE=EP KV_TRANS_USE_P2P=1 LOADWORKER=18 python -m lightllm.server.api_server \
+MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \
 --model_dir /path/DeepSeek-R1 \
 --run_mode "prefill" \
 --tp 8 \
@@ -208,7 +208,7 @@ PD (Prefill-Decode) disaggregation mode separates prefill and decode stages for
 export host=$1
 export pd_master_ip=$2
 nvidia-cuda-mps-control -d
-MOE_MODE=EP KV_TRANS_USE_P2P=1 LOADWORKER=18 python -m lightllm.server.api_server \
+MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \
 --model_dir /path/DeepSeek-R1 \
 --run_mode "decode" \
 --tp 8 \

lightllm/common/kv_cache_mem_manager/mem_manager.py

Lines changed: 85 additions & 1 deletion

@@ -2,18 +2,24 @@
 import os
 import torch
 import torch.distributed as dist
+import torch.multiprocessing as mp
 from typing import List, Union
+from lightllm.common.kv_trans_kernel.kv_trans_v2 import kv_trans_for_dp
 from lightllm.server.pd_io_struct import KVMoveTask
 from lightllm.utils.log_utils import init_logger
 from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt
 from lightllm.utils.profile_max_tokens import get_available_gpu_memory, get_total_gpu_memory
 from lightllm.common.kv_trans_kernel.kv_trans import kv_trans
-from lightllm.utils.dist_utils import get_current_rank_in_node
+from lightllm.utils.dist_utils import get_current_rank_in_node, get_node_world_size
 from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args
 from lightllm.distributed.pynccl import PyNcclCommunicator
 from lightllm.utils.dist_utils import get_current_device_id
 from lightllm.utils.config_utils import get_num_key_value_heads
 from lightllm.common.kv_trans_kernel.nixl_kv_trans import page_io
+from lightllm.utils.device_utils import kv_trans_use_p2p
+from lightllm.utils.shm_utils import create_or_link_shm
+from multiprocessing.reduction import ForkingPickler
+from filelock import FileLock

 logger = init_logger(__name__)

@@ -401,6 +407,84 @@ def get_index_kv_buffer(self, index):
     def load_index_kv_buffer(self, index, load_tensor_dict):
         self.kv_buffer[:, index].copy_(load_tensor_dict["kv_buffer"])

+    def copy_kv_from_other_dp_ranks(
+        self,
+        mem_managers: List["MemoryManager"],
+        move_token_indexes: torch.Tensor,
+        token_dp_indexes: torch.Tensor,
+        mem_indexes: torch.Tensor,
+        dp_size_in_node: int,
+        rank_in_dp: int,
+    ):
+        if not hasattr(self, "mem_ptrs_tensor"):
+            # Build a 1D tensor of length mem_num holding the kv_buffer base pointer of every mem manager.
+            mems_ptr_list = []
+            for i in range(0, len(mem_managers)):
+                mems_ptr_list.append(mem_managers[i].kv_buffer.data_ptr())
+            self.mem_ptrs_tensor = torch.tensor(mems_ptr_list, dtype=torch.uint64, device="cpu", pin_memory=True)
+
+        # Transfer all layers in a single kernel launch.
+        kv_trans_for_dp(
+            input_mems=self.mem_ptrs_tensor.cuda(non_blocking=True),
+            input_idx=move_token_indexes,
+            input_dp_idx=token_dp_indexes,
+            output=self.kv_buffer,
+            output_idx=mem_indexes,
+            dp_size_in_node=dp_size_in_node,
+            rank_in_dp=rank_in_dp,
+        )
+
+    def write_to_shm(self, req_manager):
+        """
+        Write the mem manager into shm so that features such as PD disaggregation can read it directly, without relying on inter-process queues.
+        """
+        if kv_trans_use_p2p():
+            from lightllm.server.router.model_infer.mode_backend.continues_batch.pd_mode.p2p_fix import reduce_tensor
+
+            mp.reductions.reduce_tensor.__code__ = reduce_tensor.__code__
+
+        from lightllm.common.req_manager import ReqManager
+
+        req_manager: ReqManager = req_manager
+
+        # This is not an elegant design, but for now it lets the dp shared kv swap module access req_to_token_indexs on req_manager directly,
+        # avoiding unnecessary data copies and transfer overhead.
+        self.req_to_token_indexs: torch.Tensor = req_manager.req_to_token_indexs
+
+        lock = FileLock(f"/tmp/{get_unique_server_name()}_mem_manager_lock")
+        with lock:
+            node_world_size = get_node_world_size()
+            shm_name = f"{get_unique_server_name()}_mem_manager_{get_current_rank_in_node()}"
+            obj_bytes_array = [ForkingPickler.dumps(self).tobytes() for _ in range(node_world_size * 2)]
+            obj_size = len(obj_bytes_array[0])
+            shm = create_or_link_shm(
+                name=shm_name, expected_size=obj_size * (node_world_size * 2) + 4 + 4, force_mode="create"
+            )
+            logger.info(f"create shm {shm.name} size {shm.size} for mem manger shared buffer")
+            shm.buf[0:4] = (node_world_size * 2).to_bytes(4, "little")
+            shm.buf[4:8] = obj_size.to_bytes(4, "little")
+            start_index = 8
+            for obj_bytes in obj_bytes_array:
+                shm.buf[start_index : start_index + obj_size] = obj_bytes
+                start_index += obj_size
+
+    @staticmethod
+    def loads_from_shm(rank_in_node: int) -> "MemoryManager":
+        shm_name = f"{get_unique_server_name()}_mem_manager_{rank_in_node}"
+        lock = FileLock(f"/tmp/{get_unique_server_name()}_mem_manager_lock")
+        logger.info(f"get memmanager from shm {shm_name}")
+        with lock:
+            shm = create_or_link_shm(name=shm_name, expected_size=-1, force_mode="link")
+            left_num = int.from_bytes(shm.buf[0:4], "little")
+            obj_size = int.from_bytes(shm.buf[4:8], "little")
+            assert left_num > 0
+            end_index = 8 + left_num * obj_size
+            start_index = 8 + (left_num - 1) * obj_size
+            obj_bytes = shm.buf[start_index:end_index].tobytes()
+            shm.buf[0:4] = (left_num - 1).to_bytes(4, byteorder="little")
+            shm.close()
+            return ForkingPickler.loads(obj_bytes)
+

 class ReadOnlyStaticsMemoryManager:
     """

lightllm/common/kv_trans_kernel/kv_trans_v2.py

Lines changed: 112 additions & 0 deletions

@@ -191,3 +191,115 @@ def kv_trans_v2_for_d_node(
         num_warps=1,
     )
     return
+
+
+@triton.jit
+def _kv_trans_for_dp_kernel(
+    input_mems_ptr,
+    input_stride_0,
+    input_stride_1,
+    input_stride_2,
+    input_stride_3,
+    input_token_idx_ptr,
+    input_token_dp_index_ptr,
+    output_ptr,
+    output_stride_0,
+    output_stride_1,
+    output_stride_2,
+    output_stride_3,
+    output_token_idx_ptr,
+    layer_num: tl.constexpr,
+    token_num: int,
+    head_num: int,
+    head_dim: int,
+    grid_count: int,
+    BLOCK_SIZE: tl.constexpr,
+    NUM_STAGES: tl.constexpr,
+    CARD_NUM_PER_D: tl.constexpr,
+    RANK_IN_DP: tl.constexpr,
+):
+    input_stride_0 = tl.cast(input_stride_0, dtype=tl.int64)
+    input_stride_1 = tl.cast(input_stride_1, dtype=tl.int64)
+    output_stride_0 = tl.cast(output_stride_0, dtype=tl.int64)
+    output_stride_1 = tl.cast(output_stride_1, dtype=tl.int64)
+
+    head_num_dim = head_num * head_dim
+    tid = tl.program_id(0)
+
+    offs = tl.arange(0, BLOCK_SIZE)
+    while tid < token_num:
+        dp_index = tl.load(input_token_dp_index_ptr + tid)
+        mem_index = RANK_IN_DP + dp_index * CARD_NUM_PER_D
+        input_token_idx = tl.load(input_token_idx_ptr + tid)
+        output_token_idx = tl.load(output_token_idx_ptr + tid)
+
+        input_ptr = tl.load(input_mems_ptr + mem_index).to(tl.pointer_type(output_ptr.dtype.element_ty))
+        for layer_idx in tl.range(0, layer_num, 1):
+            for block_idx in tl.range(0, tl.cdiv(head_num_dim, BLOCK_SIZE), 1, num_stages=NUM_STAGES):
+                cur_offs = block_idx * BLOCK_SIZE + offs
+                in_datas = tl.load(
+                    input_ptr + input_stride_0 * layer_idx + input_stride_1 * input_token_idx + cur_offs,
+                    mask=cur_offs < head_num_dim,
+                )
+                tl.store(
+                    output_ptr + output_stride_0 * layer_idx + output_stride_1 * output_token_idx + cur_offs,
+                    in_datas,
+                    mask=cur_offs < head_num_dim,
+                )
+
+        tid += grid_count
+
+    return
+
+
+def kv_trans_for_dp(
+    input_mems: torch.Tensor,
+    input_idx: torch.Tensor,
+    input_dp_idx: torch.Tensor,
+    output: torch.Tensor,
+    output_idx: torch.Tensor,
+    dp_size_in_node: int,
+    rank_in_dp: int,
+):
+    """
+    input_mems is a 1D torch.uint64 tensor of length mem_num; each element stores the base pointer of the kv cache held by the corresponding mem_manager currently in use.
+    """
+    assert input_mems.is_contiguous()
+    assert output.is_contiguous()
+    assert len(input_mems.shape) == 1
+    assert len(output.shape) == 4
+    assert len(input_idx) == len(output_idx)
+    assert len(output_idx) == len(input_dp_idx)
+    assert len(input_mems) % dp_size_in_node == 0
+
+    card_num_per_d = len(input_mems) // dp_size_in_node
+
+    layer_num, _, head_num, head_dim = output.shape
+    token_num = len(output_idx)
+    # Use only a small amount of resources for the data transfer to avoid occupying too many SM compute units.
+    grid_count = 20
+    BLOCK_SIZE = 256
+    NUM_STAGES = 3
+    grid = (grid_count,)
+
+    _kv_trans_for_dp_kernel[grid](
+        input_mems,
+        *output.stride(),
+        input_idx,
+        input_dp_idx,
+        output,
+        *output.stride(),
+        output_idx,
+        layer_num=layer_num,
+        token_num=token_num,
+        head_num=head_num,
+        head_dim=head_dim,
+        grid_count=grid_count,
+        BLOCK_SIZE=BLOCK_SIZE,
+        NUM_STAGES=NUM_STAGES,
+        CARD_NUM_PER_D=card_num_per_d,
+        RANK_IN_DP=rank_in_dp,
+        num_warps=1,
+    )
+
+    return
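A minimal usage sketch of kv_trans_for_dp, pulling four tokens from the kv buffers of two hypothetical DP groups (one card per group) into the local buffer. All shapes, index values, and the src_buffers stand-ins are illustrative only; a CUDA device with Triton installed is assumed:

# Toy call of kv_trans_for_dp: copy tokens owned by two hypothetical DP ranks
# into the local kv buffer (requires a CUDA GPU with Triton available).
import torch
from lightllm.common.kv_trans_kernel.kv_trans_v2 import kv_trans_for_dp

layer_num, tokens, head_num, head_dim = 2, 16, 4, 8
# Stand-ins for the kv_buffer of each source rank's MemoryManager (1 card per DP group).
src_buffers = [
    torch.randn(layer_num, tokens, head_num, head_dim, dtype=torch.float16, device="cuda")
    for _ in range(2)
]
dst = torch.zeros_like(src_buffers[0])

# 1D uint64 tensor holding the base pointer of every source kv buffer, as in copy_kv_from_other_dp_ranks.
input_mems = torch.tensor(
    [b.data_ptr() for b in src_buffers], dtype=torch.uint64, device="cpu", pin_memory=True
).cuda()

move_token_indexes = torch.tensor([0, 1, 2, 3], dtype=torch.int64, device="cuda")  # source token slots
token_dp_indexes = torch.tensor([0, 0, 1, 1], dtype=torch.int64, device="cuda")    # owning DP group per token
mem_indexes = torch.tensor([4, 5, 6, 7], dtype=torch.int64, device="cuda")         # destination token slots

kv_trans_for_dp(
    input_mems=input_mems,
    input_idx=move_token_indexes,
    input_dp_idx=token_dp_indexes,
    output=dst,
    output_idx=mem_indexes,
    dp_size_in_node=2,  # two DP groups on this node
    rank_in_dp=0,       # this process is rank 0 inside its DP group
)
# Token 0 of DP group 0 should now sit in destination slot 4 across all layers.
torch.testing.assert_close(dst[:, 4], src_buffers[0][:, 0])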

lightllm/server/api_cli.py

Lines changed: 6 additions & 0 deletions

@@ -566,4 +566,10 @@ def make_argument_parser() -> argparse.ArgumentParser:
         default=None,
         help="""Directory used to persist disk cache data. Defaults to a temp directory when not set.""",
     )
+    parser.add_argument(
+        "--enable_dp_prompt_cache_fetch",
+        action="store_true",
+        default=False,
+        help="""Enable prefix prompt cache fetch for data parallel inference, disabled by default.""",
+    )
     return parser

lightllm/server/api_start.py

Lines changed: 8 additions & 0 deletions

@@ -269,6 +269,14 @@ def normal_or_p_d_start(args):
        args.router_max_wait_tokens = 0

    send_and_receive_node_ip(args)  # exchange node IPs across machines in multi-node deployments
+    # dp must be > 1 for this feature
+    if args.enable_dp_prompt_cache_fetch and args.dp <= 1:
+        args.enable_dp_prompt_cache_fetch = False
+        logger.warning(
+            """dp <= 1 does not support dp_prompt_cache_fetch;
+            overriding enable_dp_prompt_cache_fetch to False"""
+        )
+
    set_env_start_args(args)
    logger.info(f"all start args:{args}")

lightllm/server/core/objs/start_args_type.py

Lines changed: 1 addition & 0 deletions

@@ -114,6 +114,7 @@ class StartArgs:
     enable_disk_cache: bool = field(default=False)
     disk_cache_storage_size: float = field(default=10)
     disk_cache_dir: Optional[str] = field(default=None)
+    enable_dp_prompt_cache_fetch: bool = field(default=False)
     # zmp ports
     router_port: int = field(default=None)
     detokenization_port: int = field(default=None)

lightllm/server/router/manager.py

Lines changed: 4 additions & 8 deletions

@@ -116,9 +116,6 @@ async def wait_to_model_ready(self):
        self.model_rpc_servers = []
        # Queue used by the kv move manager process and the inference processes to exchange task info.
        self.info_queue: mp.Queue = mp.Queue()
-        self.mem_queues: List[torch.multiprocessing.Queue] = [
-            torch.multiprocessing.Queue() for _ in range(self.node_world_size)
-        ]
        self.rpc_event = multiprocessing.Event()
        self.rpc_finished_event = multiprocessing.Event()

@@ -137,7 +134,6 @@ async def wait_to_model_ready(self):
                rpc_event=self.rpc_event,
                rpc_finished_event=self.rpc_finished_event,
                info_queue=self.info_queue,
-                mem_queue=self.mem_queues[(rank_id % node_world_size)],
                router_lock=self.router_lock,
            )
        )
@@ -205,29 +201,29 @@ async def wait_to_model_ready(self):
                start_prefill_kv_move_manager_process,
            )

-            start_prefill_kv_move_manager_process(self.args, self.info_queue, self.mem_queues)
+            start_prefill_kv_move_manager_process(self.args, self.info_queue)

        if self.args.run_mode == "nixl_prefill":
            from lightllm.server.router.model_infer.mode_backend.pd_nixl.prefill_node_impl import (
                start_prefill_kv_move_manager_process,
            )

-            start_prefill_kv_move_manager_process(self.args, self.info_queue, self.mem_queues)
+            start_prefill_kv_move_manager_process(self.args, self.info_queue)

        if self.args.run_mode == "decode":
            # Start the decode kv move manager process.
            from lightllm.server.router.model_infer.mode_backend.continues_batch.pd_mode.decode_node_impl import (
                start_decode_kv_move_manager_process,
            )

-            start_decode_kv_move_manager_process(self.args, self.info_queue, self.mem_queues)
+            start_decode_kv_move_manager_process(self.args, self.info_queue)

        if self.args.run_mode == "nixl_decode":
            from lightllm.server.router.model_infer.mode_backend.pd_nixl.decode_node_impl import (
                start_decode_kv_move_manager_process,
            )

-            start_decode_kv_move_manager_process(self.args, self.info_queue, self.mem_queues)
+            start_decode_kv_move_manager_process(self.args, self.info_queue)

        return
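With the per-rank mem_queues removed here, a process that previously received each rank's MemoryManager over a torch.multiprocessing.Queue can instead rebuild the handles from the shared-memory snapshots written by write_to_shm. A minimal consumer-side sketch; whether the kv move manager processes follow exactly this path is not shown in this hunk, and the node_world_size value is illustrative:

# Hypothetical consumer side: rebuild per-rank MemoryManager handles from the shm
# snapshots written by write_to_shm, instead of receiving them over mp queues.
# Assumes every rank on this node has already called write_to_shm.
from typing import List

from lightllm.common.kv_cache_mem_manager.mem_manager import MemoryManager


def collect_mem_managers(node_world_size: int) -> List[MemoryManager]:
    # One shm segment per rank in the node; each loads_from_shm call pops one
    # pickled copy of that rank's MemoryManager and decrements its slot counter.
    return [MemoryManager.loads_from_shm(rank) for rank in range(node_world_size)]


mem_managers = collect_mem_managers(node_world_size=8)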
