Commit bb9d6fc

Author: niushengxiao (committed)
feat: support page size variable for deepseek2
1 parent: 1a6af0a

File tree: 7 files changed, +147 -28 lines changed
lightllm/common/deepseek2_page_size_variable_mem_manager.py

Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
+import torch
+import numpy as np
+from .deepseek2_mem_manager import Deepseek2MemoryManager
+from .page_size_variable_mem_manager import PageSizeVariableMemoryManager
+from lightllm.utils.log_utils import init_logger
+from lightllm.utils.envs_utils import get_page_size
+
+
+def cdiv(a, b):
+    return (a + b - 1) // b
+
+
+logger = init_logger(__name__)
+
+
+class Deepseek2PageSizeVariableMemoryManager(PageSizeVariableMemoryManager, Deepseek2MemoryManager):
+    def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9):
+        super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction)
+
+    def _init_buffers(self, size, dtype, head_num, head_dim, layer_num):
+        self.kv_buffer = torch.empty(
+            (layer_num, cdiv(size, get_page_size()) * get_page_size(), head_num, head_dim),
+            dtype=dtype,
+            device="cuda",
+        )
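
Note: the only behavioral change versus the base Deepseek2MemoryManager is that the per-layer kv_buffer row count is rounded up to a multiple of the page size, so the last page is never partially allocated. A minimal sketch of that rounding, assuming get_page_size() returns 64 (the real value comes from the environment):

    page_size = 64
    size = 1000                                                  # requested number of KV cache slots
    padded = ((size + page_size - 1) // page_size) * page_size   # cdiv(size, page_size) * page_size == 1024
    # kv_buffer is then allocated with `padded` rows per layer, i.e. 16 full pages of 64 slots.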

lightllm/models/deepseek2/flashattention_infer_struct.py

Lines changed: 21 additions & 12 deletions
@@ -4,6 +4,11 @@
 import torch.distributed as dist
 from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo
 from lightllm.utils.dist_utils import get_current_device_id
+from lightllm.utils.envs_utils import get_page_size
+
+
+def cdiv(a, b):
+    return (a + b - 1) // b


 class Deepseek2FlashAttentionStateInfo(Deepseek2InferStateInfo):
@@ -38,20 +43,24 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
         self.cu_seqlens_q = self.b1_cu_q_seq_len
         self.cu_seqlens_k = self.b1_cu_kv_seq_len
         max_seq_len_k = self.max_kv_seq_len
+        page_size = get_page_size()
         if self.batch_size <= model.graph_max_batch_size and self.max_len_in_batch <= model.graph_max_len_in_batch:
-            page_buffer = Deepseek2FlashAttentionStateInfo.get_page_table_buffer(
-                model.graph_max_batch_size, model.graph_max_len_in_batch
+            length = cdiv(model.graph_max_len_in_batch, page_size)
+            page_buffer = Deepseek2FlashAttentionStateInfo.get_page_table_buffer(model.graph_max_batch_size, length)
+            self.page_table = page_buffer[self.microbatch_index][: self.batch_size * length].reshape(
+                self.batch_size, length
             )
-            self.page_table = page_buffer[self.microbatch_index][
-                : self.batch_size * model.graph_max_len_in_batch
-            ].reshape(self.batch_size, model.graph_max_len_in_batch)
         else:
-            self.page_table = torch.empty((self.batch_size, self.max_len_in_batch), dtype=torch.int32).to(
-                input_ids.device
-            )
+            length = cdiv(self.max_len_in_batch, page_size)
+            self.page_table = torch.empty((self.batch_size, length), dtype=torch.int32).to(input_ids.device)

-        self.page_table[:, :max_seq_len_k].copy_(
-            model.req_manager.req_to_token_indexs[self.b_req_idx, :max_seq_len_k]
-        )
-        self.page_table[:, max_seq_len_k:].fill_(0)
+        if "page_size_variable" in model.mode:
+            length = cdiv(max_seq_len_k, page_size)
+            self.page_table[:, :length].copy_(model.req_manager.req_to_page_indexs[self.b_req_idx, :length])
+            self.page_table[:, length:].fill_(0)
+        else:
+            self.page_table[:, :max_seq_len_k].copy_(
+                model.req_manager.req_to_token_indexs[self.b_req_idx, :max_seq_len_k]
+            )
+            self.page_table[:, max_seq_len_k:].fill_(0)
         return
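
Note: in page_size_variable mode each row of page_table holds page indices taken from req_to_page_indexs rather than per-token indices, so the table is roughly page_size times narrower. A rough illustration of the widths computed above (numbers are made up, not from the commit):

    page_size = 64
    max_len_in_batch = 1000
    max_seq_len_k = 900
    table_width = (max_len_in_batch + page_size - 1) // page_size   # 16 columns allocated
    valid_cols = (max_seq_len_k + page_size - 1) // page_size       # first 15 columns get real page ids
    # page_table[:, :valid_cols] is copied from req_to_page_indexs; the remaining columns are zero-filled.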

lightllm/models/deepseek2/flashinfer_struct.py

Lines changed: 31 additions & 13 deletions
@@ -3,16 +3,21 @@
 import numpy as np
 import torch.distributed as dist
 from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo
-from lightllm.utils.envs_utils import get_env_start_args
+from lightllm.utils.envs_utils import get_env_start_args, get_page_size
 from lightllm.models.deepseek2.triton_kernel.repack_kv_index import repack_kv_index


+def cdiv(a, b):
+    return (a + b - 1) // b
+
+
 class Deepseek2FlashInferStateInfo(Deepseek2InferStateInfo):
     def __init__(self):
         super().__init__()
         self.prefill_wrapper = None
         self.decode_wrapper = None
         self.flashinfer_extra_state = None
+        self.page_size = get_page_size()

     def init_some_extra_state(self, model, input_ids: torch.Tensor):
         super().init_some_extra_state(model, input_ids)
@@ -23,24 +28,37 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
         if not self.is_prefill:
             if get_env_start_args().enable_flashinfer_decode:
                 self.q_indptr = torch.arange(self.batch_size + 1, dtype=torch.int32).to(input_ids.device)
+                length = cdiv(self.flashinfer_extra_state.max_seq_length, self.page_size)
                 if self.batch_size <= model.graph_max_batch_size:
                     self.kv_indices = self.flashinfer_extra_state.kv_indices_buffer[self.microbatch_index][
-                        : self.batch_size * self.flashinfer_extra_state.max_seq_length
+                        : self.batch_size * length
                     ]
                 else:
                     self.kv_indices = torch.empty(
-                        self.batch_size * self.flashinfer_extra_state.max_seq_length,
+                        self.batch_size * length,
                         dtype=torch.int32,
                         device=input_ids.device,
                     )
-                repack_kv_index(
-                    self.req_manager.req_to_token_indexs,
-                    self.b_req_idx,
-                    self.b_seq_len,
-                    self.b_start_loc,
-                    self.max_len_in_batch,
-                    self.kv_indices,
-                )
+                if "page_size_variable" in model.mode:
+                    b_page_len = cdiv(self.b_seq_len, self.page_size)
+                    self.kv_starts[1:] = b_page_len.cumsum(0)
+                    repack_kv_index(
+                        self.req_manager.req_to_page_indexs,
+                        self.b_req_idx,
+                        b_page_len,
+                        self.kv_starts[:-1],
+                        cdiv(self.max_len_in_batch, self.page_size),
+                        self.kv_indices,
+                    )
+                else:
+                    repack_kv_index(
+                        self.req_manager.req_to_token_indexs,
+                        self.b_req_idx,
+                        self.b_seq_len,
+                        self.b_start_loc,
+                        self.max_len_in_batch,
+                        self.kv_indices,
+                    )
                 if self.decode_wrapper is None:
                     self.decode_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
                         self.flashinfer_extra_state.workspace_buffer,
@@ -58,7 +76,7 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
                    self.flashinfer_extra_state.tp_q_head_num,
                    self.flashinfer_extra_state.kv_lora_rank,
                    self.flashinfer_extra_state.qk_rope_head_dim,
-                    1,
+                    self.page_size,
                    False,  # causal
                    self.flashinfer_extra_state.softmax_scale,
                    self.flashinfer_extra_state.q_data_type,
@@ -97,7 +115,7 @@ def copy_for_cuda_graph(self, new_infer_state):
                new_infer_state.flashinfer_extra_state.tp_q_head_num,
                new_infer_state.flashinfer_extra_state.kv_lora_rank,
                new_infer_state.flashinfer_extra_state.qk_rope_head_dim,
-                1,
+                self.page_size,
                False,  # causal
                new_infer_state.flashinfer_extra_state.softmax_scale,
                new_infer_state.flashinfer_extra_state.q_data_type,
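
Note: for the flashinfer decode path, per-request sequence lengths are first converted to page counts, and kv_starts becomes their prefix sum, which gives each request its offset into the packed kv_indices buffer that repack_kv_index fills. A small illustration with made-up values (assumes b_seq_len is an int32 tensor, as in the surrounding code):

    import torch

    page_size = 4
    b_seq_len = torch.tensor([5, 12, 3], dtype=torch.int32)
    b_page_len = (b_seq_len + page_size - 1) // page_size       # tensor([2, 3, 1])
    kv_starts = torch.zeros(b_seq_len.numel() + 1, dtype=torch.int32)
    kv_starts[1:] = b_page_len.cumsum(0)                        # tensor([0, 2, 5, 6])
    # request i owns kv_indices[kv_starts[i]:kv_starts[i + 1]], i.e. its page ids stored back to back.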

lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py

Lines changed: 63 additions & 1 deletion
@@ -26,7 +26,7 @@
 from functools import partial
 from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale
 from lightllm.distributed.communication_op import all_gather, all_gather_into_tensor, all_reduce, reduce_scatter_tensor
-from lightllm.utils.envs_utils import get_env_start_args
+from lightllm.utils.envs_utils import get_env_start_args, get_page_size
 from lightllm.utils.dist_utils import get_global_world_size
 from lightllm.utils.log_utils import init_logger
 from lightllm.utils.sgl_utils import flash_attn_varlen_func, flash_attn_with_kvcache, merge_state_v2
@@ -94,6 +94,18 @@ def _bind_attention(self):
                self._token_attention_kernel = partial(
                    Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashdecoding_fp8, self
                )
+        elif "page_size_variable" in self.mode:
+            self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
+            if get_env_start_args().enable_fa3:
+                self._token_attention_kernel = partial(
+                    Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashattention_paged, self
+                )
+            elif get_env_start_args().enable_flashinfer_decode:
+                self._token_attention_kernel = partial(
+                    Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashinfer_paged, self
+                )
+            else:
+                raise Exception("Page size variable mode is not supported in other backends.")
         else:
             self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
             if get_env_start_args().enable_fa3:
@@ -575,6 +587,36 @@ def _token_gqa_decode_attention_flashattention(
         )
         return o_tensor

+    def _token_gqa_decode_attention_flashattention_paged(
+        self, q, infer_state: Deepseek2FlashInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
+    ):
+        page_size = get_page_size()
+        q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
+        q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1)
+        kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
+        k_rope = kv[:, :, -self.qk_rope_head_dim :].reshape(-1, page_size, 1, self.qk_rope_head_dim)
+        kv_nope = kv[:, :, : -self.qk_rope_head_dim].reshape(-1, page_size, 1, self.kv_lora_rank)
+        k_descale, v_descale = None, None
+        o_tensor = flash_attn_with_kvcache(
+            q=q_rope,
+            k_cache=k_rope,
+            v_cache=kv_nope,
+            qv=q_nope,
+            page_table=infer_state.page_table,
+            cache_seqlens=infer_state.b_seq_len,
+            cu_seqlens_q=infer_state.cu_seqlens_q,
+            cu_seqlens_k_new=infer_state.cu_seqlens_k,
+            max_seqlen_q=1,
+            softmax_scale=self.softmax_scale,
+            causal=True,
+            window_size=(-1, -1),
+            softcap=0.0,
+            k_descale=k_descale,
+            v_descale=v_descale,
+            return_softmax_lse=False,
+        )
+        return o_tensor
+
     def _token_gqa_decode_attention_flashinfer(
         self, q, infer_state: Deepseek2FlashInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
     ):
@@ -594,6 +636,26 @@ def _token_gqa_decode_attention_flashinfer(
         )
         return o_tensor

+    def _token_gqa_decode_attention_flashinfer_paged(
+        self, q, infer_state: Deepseek2FlashInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
+    ):
+        page_size = get_page_size()
+        q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
+        q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1)
+
+        kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
+        o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype)
+
+        infer_state.decode_wrapper.run(
+            q_nope,
+            q_rope,
+            kv[:, :, : -self.qk_rope_head_dim].reshape(-1, page_size, 1, self.kv_lora_rank),
+            kv[:, :, -self.qk_rope_head_dim :].reshape(-1, page_size, 1, self.qk_rope_head_dim),
+            out=o_tensor,
+            return_lse=False,
+        )
+        return o_tensor
+
     def _token_gqa_decode_attention_flashdecoding(
         self, q, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
     ):
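
Note: both paged decode kernels view the flat per-layer kv_buffer, shaped (num_tokens, head_num, head_dim), as (num_pages, page_size, head_num, head_dim) so that the page table (or the flashinfer wrapper) can address whole pages; this only works because the memory manager pads the buffer to a multiple of page_size. A shape-only sketch with illustrative numbers (576 = kv_lora_rank 512 + qk_rope_head_dim 64 in DeepSeek-V2-style MLA; the exact values are not part of this commit):

    import torch

    page_size, head_dim, num_tokens = 4, 576, 16
    kv = torch.zeros(num_tokens, 1, head_dim)                    # one layer of the kv cache
    kv_paged = kv.reshape(-1, page_size, 1, head_dim)            # -> (4 pages, 4 slots, 1 head, 576)
    kv_nope, k_rope = kv_paged[..., :512], kv_paged[..., 512:]   # same nope/rope split as in the kernels above
    print(kv_paged.shape, kv_nope.shape, k_rope.shape)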

lightllm/models/deepseek2/model.py

Lines changed: 5 additions & 0 deletions
@@ -10,6 +10,7 @@

 from lightllm.models.llama.model import LlamaTpPartModel
 from lightllm.common.deepseek2_mem_manager import Deepseek2MemoryManager
+from lightllm.common.deepseek2_page_size_variable_mem_manager import Deepseek2PageSizeVariableMemoryManager
 from lightllm.common.deepseek2_fp8kv_mem_manager import Deepseek2FP8KVMemoryManager
 from lightllm.utils.log_utils import init_logger
 from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale
@@ -97,6 +98,10 @@ def _init_mem_manager(self):
         manager_class = Deepseek2MemoryManager
         if "triton_fp8kv" in self.mode:
             manager_class = Deepseek2FP8KVMemoryManager
+        elif "page_size_variable" in self.mode:
+            manager_class = Deepseek2PageSizeVariableMemoryManager
+        elif self.mode:
+            raise ValueError(f"Unsupported mode for deepseek2: {self.mode}")

         # In mtp mode, the mem manager needs to be extended with the layers used by the draft model
         added_mtp_layer_num = 0
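
Note: the selection above now rejects unrecognized modes explicitly instead of silently falling back to the dense manager. A small sketch of the resulting precedence, assuming self.mode is a list of mode strings as the membership tests suggest (pick_manager is a hypothetical helper, not part of the commit):

    def pick_manager(mode):
        if "triton_fp8kv" in mode:
            return "Deepseek2FP8KVMemoryManager"
        elif "page_size_variable" in mode:
            return "Deepseek2PageSizeVariableMemoryManager"
        elif mode:
            raise ValueError(f"Unsupported mode for deepseek2: {mode}")
        return "Deepseek2MemoryManager"

    pick_manager([])                        # 'Deepseek2MemoryManager'
    pick_manager(["page_size_variable"])    # 'Deepseek2PageSizeVariableMemoryManager'
    pick_manager(["some_other_mode"])       # raises ValueError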

lightllm/models/deepseek2/triton_kernel/destindex_copy_kv.py

Lines changed: 1 addition & 1 deletion
@@ -34,7 +34,7 @@ def _fwd_kernel_destindex_copy_kv(
     offs_d_nope = tl.arange(0, BLOCK_DMODEL_NOPE)
     offs_d_rope = tl.arange(0, BLOCK_DMODEL_ROPE)

-    dest_index = tl.load(Dest_loc + cur_index)
+    dest_index = tl.load(Dest_loc + cur_index).to(tl.int64)

     kv_nope_ptrs = KV_nope + cur_index * stride_kv_nope_bs + stride_kv_nope_d * offs_d_nope[None, :]
     kv_rope_ptrs = KV_rope + cur_index * stride_kv_rope_bs + stride_kv_rope_d * offs_d_rope[None, :]

lightllm/models/deepseek2/triton_kernel/sample_kv.py

Lines changed: 1 addition & 1 deletion
@@ -44,7 +44,7 @@ def _sample_kv_kernel(
        Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_m,
        mask=offs_m < block_end_loc,
        other=0,
-    )
+    ).to(tl.int64)
    off_kv_nope = kv_loc[:, None] * stride_input_dim + offs_nope_d[None, :]
    off_kv_rope = kv_loc[:, None] * stride_input_dim + (offs_rope_d + BLOCK_DMODEL)[None, :]
    kv_nope = tl.load(KV_input + off_kv_nope, mask=offs_m[:, None] < block_end_loc, other=0.0)
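
Note on the two one-line Triton changes: the loaded destination/token indices feed pointer arithmetic of the form index * stride, and with a large, page-padded kv_buffer that product can exceed the int32 range; casting the loaded index to tl.int64 before the multiply keeps the computed offsets valid. A made-up magnitude check:

    index = 3_000_000          # a large slot index into the kv buffer
    stride = 1_152             # illustrative per-slot stride, not taken from the kernels
    print(index * stride)      # 3_456_000_000 > 2**31 - 1 = 2_147_483_647, would overflow int32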
