Commit a41492e

hiworldwzj and wangzaijun authored
cpu kv cache support quanted kv. (#1133)
Co-authored-by: wangzaijun <wangzaijun@sensetime.com>
1 parent 160c1b0 commit a41492e

41 files changed: +504 -105 lines changed

lightllm/common/basemodel/basemodel.py

Lines changed: 5 additions & 15 deletions
@@ -11,7 +11,8 @@
 
 from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights
 from lightllm.common.basemodel.infer_struct import InferStateInfo
-from lightllm.common.mem_manager import MemoryManager
+from lightllm.common.kv_cache_mem_manager import MemoryManager
+from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class
 from lightllm.common.req_manager import ReqManager
 from lightllm.common.infer_utils import init_req_to_token_indexes
 from lightllm.common.build_utils import repair_config
@@ -22,7 +23,7 @@
 from lightllm.common.basemodel.triton_kernel.gather_token_id import gather_token
 from lightllm.utils.log_utils import init_logger
 from lightllm.utils.dist_utils import get_dp_world_size
-from lightllm.utils.envs_utils import get_env_start_args
+from lightllm.utils.envs_utils import get_env_start_args, get_llm_data_type
 from lightllm.distributed.communication_op import dist_group_manager
 from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput
 from lightllm.common.triton_utils.autotuner import AutotuneLevel
@@ -68,7 +69,7 @@ def __init__(self, kvargs):
         self.is_token_healing = kvargs.get("is_token_healing", False)
         self.return_all_prompt_logics = kvargs.get("return_all_prompt_logics", False)
         assert not (self.is_token_healing and self.return_all_prompt_logics), "can not be true in same time"
-        self.data_type = kvargs.get("data_type", "float16")
+        self.data_type = get_llm_data_type()
         mtp_step = get_env_start_args().mtp_step
         self.graph_max_batch_size = kvargs.get("graph_max_batch_size", 16)
         self.graph_max_batch_size = (
@@ -89,7 +90,6 @@ def __init__(self, kvargs):
 
         self.is_deepseekv3_mtp_mode = self.args.mtp_mode in ["deepseekv3_vanilla", "deepseekv3_eagle"]
 
-        self._init_datatype()
         self._init_config()
         self._verify_must()
         self._verify_params()
@@ -180,7 +180,7 @@ def _init_weights(self):
 
     def _init_mem_manager(self):
        assert self.config["num_attention_heads"] % self.tp_world_size_ == 0
-        self.mem_manager = MemoryManager(
+        self.mem_manager: MemoryManager = select_mem_manager_class()(
             self.max_total_token_num,
             dtype=self.data_type,
             head_num=self.config["num_attention_heads"] // self.tp_world_size_,
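
The call-site change above is the heart of this commit: _init_mem_manager no longer hard-codes MemoryManager but asks select_mem_manager_class() for the class to instantiate, which is what lets a quantized-KV cache manager be swapped in. The factory itself lives in lightllm/common/kv_cache_mem_manager/mem_utils.py and is not part of this excerpt; a minimal sketch of the dispatch pattern, with QuantizedMemoryManager and the kv_quant_type parameter as purely hypothetical names, could look like:

# Minimal sketch of the factory dispatch, under assumed names; the real
# select_mem_manager_class in mem_utils.py is not shown in this commit excerpt.
from lightllm.common.kv_cache_mem_manager import MemoryManager


class QuantizedMemoryManager(MemoryManager):
    # Hypothetical subclass that would store K/V tensors in a quantized layout.
    pass


def select_mem_manager_class(kv_quant_type: str = "none") -> type:
    # Return the MemoryManager subclass matching the configured KV quantization.
    if kv_quant_type == "none":
        return MemoryManager
    return QuantizedMemoryManager

Note that the call site only changes from MemoryManager(...) to select_mem_manager_class()(...), so whatever class the factory returns must accept the same constructor arguments (max_total_token_num, dtype, head_num, and so on).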
@@ -230,16 +230,6 @@ def _init_some_value(self):
         self.vocab_size = self.config["vocab_size"]
         return
 
-    def _init_datatype(self):
-        if self.data_type in ["fp16", "float16"]:
-            self.data_type = torch.float16
-        elif self.data_type in ["bf16", "bfloat16"]:
-            self.data_type = torch.bfloat16
-        elif self.data_type in ["fp32", "float32"]:
-            self.data_type = torch.float32
-        else:
-            raise ValueError(f"Unsupport datatype {self.data_type}!")
-
     def _init_cudagraph(self):
         self.graph = (
             None if self.disable_cudagraph else CudaGraph(self.graph_max_batch_size, self.graph_max_len_in_batch)
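
The deleted _init_datatype() makes the refactor easy to read: the string-to-torch.dtype mapping moves out of the model and behind get_llm_data_type() from envs_utils (imported in the first hunk), so the data type is resolved in one place. The helper's body is not in this diff; judging from the deleted code, a plausible equivalent is:

import torch

# Plausible shape of get_llm_data_type(), inferred from the deleted
# _init_datatype(); the real implementation in envs_utils is not shown here,
# and presumably reads the configured name from the server start args
# rather than taking it as a parameter.
_NAME_TO_DTYPE = {
    "fp16": torch.float16,
    "float16": torch.float16,
    "bf16": torch.bfloat16,
    "bfloat16": torch.bfloat16,
    "fp32": torch.float32,
    "float32": torch.float32,
}


def get_llm_data_type(name: str = "float16") -> torch.dtype:
    # Resolve the configured data type once, instead of per model instance.
    if name not in _NAME_TO_DTYPE:
        raise ValueError(f"Unsupported datatype {name}!")
    return _NAME_TO_DTYPE[name]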

lightllm/common/basemodel/infer_struct.py

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 import torch
 import triton
 import collections
-from lightllm.common.mem_manager import MemoryManager
+from lightllm.common.kv_cache_mem_manager import MemoryManager
 from lightllm.common.req_manager import ReqManager
 from lightllm.distributed import CustomProcessGroup
 from typing import Tuple, Any, Optional, List
