@@ -11,7 +11,8 @@
 
 from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights
 from lightllm.common.basemodel.infer_struct import InferStateInfo
-from lightllm.common.mem_manager import MemoryManager
+from lightllm.common.kv_cache_mem_manager import MemoryManager
+from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class
 from lightllm.common.req_manager import ReqManager
 from lightllm.common.infer_utils import init_req_to_token_indexes
 from lightllm.common.build_utils import repair_config
@@ -22,7 +23,7 @@
 from lightllm.common.basemodel.triton_kernel.gather_token_id import gather_token
 from lightllm.utils.log_utils import init_logger
 from lightllm.utils.dist_utils import get_dp_world_size
-from lightllm.utils.envs_utils import get_env_start_args
+from lightllm.utils.envs_utils import get_env_start_args, get_llm_data_type
 from lightllm.distributed.communication_op import dist_group_manager
 from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput
 from lightllm.common.triton_utils.autotuner import AutotuneLevel
@@ -68,7 +69,7 @@ def __init__(self, kvargs):
         self.is_token_healing = kvargs.get("is_token_healing", False)
         self.return_all_prompt_logics = kvargs.get("return_all_prompt_logics", False)
         assert not (self.is_token_healing and self.return_all_prompt_logics), "can not be true in same time"
-        self.data_type = kvargs.get("data_type", "float16")
+        self.data_type = get_llm_data_type()
         mtp_step = get_env_start_args().mtp_step
         self.graph_max_batch_size = kvargs.get("graph_max_batch_size", 16)
         self.graph_max_batch_size = (
@@ -89,7 +90,6 @@ def __init__(self, kvargs):
 
         self.is_deepseekv3_mtp_mode = self.args.mtp_mode in ["deepseekv3_vanilla", "deepseekv3_eagle"]
 
-        self._init_datatype()
         self._init_config()
         self._verify_must()
         self._verify_params()
@@ -180,7 +180,7 @@ def _init_weights(self):
 
     def _init_mem_manager(self):
         assert self.config["num_attention_heads"] % self.tp_world_size_ == 0
-        self.mem_manager = MemoryManager(
+        self.mem_manager: MemoryManager = select_mem_manager_class()(
             self.max_total_token_num,
             dtype=self.data_type,
             head_num=self.config["num_attention_heads"] // self.tp_world_size_,
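For context, a minimal sketch of how the new factory plausibly works, assuming `select_mem_manager_class()` consults the server start args to pick a `MemoryManager` subclass. The actual selection logic lives in `lightllm/common/kv_cache_mem_manager/mem_utils.py` and is not shown in this diff; everything below is an assumption.

```python
# Hypothetical sketch only: the real mem_utils may choose among several
# MemoryManager subclasses (e.g. quantized-KV-cache variants).
from lightllm.common.kv_cache_mem_manager import MemoryManager
from lightllm.utils.envs_utils import get_env_start_args


def select_mem_manager_class() -> type:
    args = get_env_start_args()
    # A specialized subclass could be returned here based on args;
    # the plain MemoryManager stands in as the default.
    return MemoryManager
```

Note that the call site instantiates the returned class immediately, `select_mem_manager_class()(self.max_total_token_num, ...)`, so `_init_mem_manager` stays agnostic of which concrete manager is configured.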
@@ -230,16 +230,6 @@ def _init_some_value(self):
         self.vocab_size = self.config["vocab_size"]
         return
 
-    def _init_datatype(self):
-        if self.data_type in ["fp16", "float16"]:
-            self.data_type = torch.float16
-        elif self.data_type in ["bf16", "bfloat16"]:
-            self.data_type = torch.bfloat16
-        elif self.data_type in ["fp32", "float32"]:
-            self.data_type = torch.float32
-        else:
-            raise ValueError(f"Unsupport datatype {self.data_type}!")
-
     def _init_cudagraph(self):
         self.graph = (
             None if self.disable_cudagraph else CudaGraph(self.graph_max_batch_size, self.graph_max_len_in_batch)
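The `_init_datatype` method removed above is presumably what `get_llm_data_type()` now centralizes; a minimal sketch under that assumption, mirroring the removed mapping (the real helper in `lightllm.utils.envs_utils` may read the dtype from elsewhere and differ in detail):

```python
import torch

from lightllm.utils.envs_utils import get_env_start_args


def get_llm_data_type() -> torch.dtype:
    # Assumed: the dtype string now comes from the start args rather than
    # per-model kvargs; the mapping mirrors the removed _init_datatype.
    data_type = get_env_start_args().data_type
    if data_type in ["fp16", "float16"]:
        return torch.float16
    elif data_type in ["bf16", "bfloat16"]:
        return torch.bfloat16
    elif data_type in ["fp32", "float32"]:
        return torch.float32
    raise ValueError(f"Unsupported datatype {data_type}!")
```

This removes the per-model string-to-`torch.dtype` conversion, so every model resolves the same dtype from a single source.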