@@ -1,26 +1,20 @@
-import copy
 import time
-import uuid
 import uvloop
 import asyncio
 import torch
-import rpyc
 import pickle
-import threading
 import inspect
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
-import concurrent.futures
 import zmq
 import zmq.asyncio
 import torch.multiprocessing as mp
 import torch.distributed as dist
 import multiprocessing
 from typing import Dict, List, Optional
-from .batch import Batch
+from .batch import Batch, Req
 from .model_infer.model_rpc import start_model_process, ModelRpcClient
 from .req_queue import build_req_queue
-from lightllm.utils.infer_utils import calculate_time
 from lightllm.server.core.objs.io_objs import GroupReqIndexes
 from lightllm.server.core.objs import ShmReqManager, StartArgs
 from .dynamic_prompt.radix_cache import RadixCacheReadOnlyClient
@@ -79,7 +73,7 @@ def __init__(self, args: StartArgs, router_port, detokenization_port, metric_por
         self.eos_id = args.eos_id
         self.has_wait_tokens = 0
         self.max_wait_tokens = args.router_max_wait_tokens
-        context = zmq.asyncio.Context(2)
+        context = zmq.Context(2)
         self.recv_from_httpserver = context.socket(zmq.PULL)
         self.recv_from_httpserver.bind(f"{args.zmq_mode}127.0.0.1:{router_port}")
 
@@ -106,13 +100,13 @@ def __init__(self, args: StartArgs, router_port, detokenization_port, metric_por
         # Mainly to guard against scheduling mistakes that could cause OOM and similar errors
         self.router_lock = mp.Lock()
         g_router_lock.obj = self.router_lock
-
-        # Thread pool used to overlap scheduling with inference
-        self.overlap_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
-        self.schedule_task = None
         return
 
     async def wait_to_model_ready(self):
+        # Objects used for scheduling
+        self.schedule_new_batch: Batch = None
+        self.schedule_event = asyncio.Event()
+
         # Initialize the model
         self.model_rpc_servers = []
         # Used to exchange task information between the kv-move manager process and the inference processes.
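The two new scheduler objects above replace the thread-pool overlap: `schedule_new_batch` buffers whatever the scheduler produced but inference has not consumed yet, and `schedule_event` lets the decode path wake the scheduling coroutine on the same event loop. A minimal sketch of this handoff pattern, with hypothetical stand-ins (`decode_loop`, `schedule_loop`, the `"new-batch"` payload) for the real lightllm methods:

```python
import asyncio

# Hedged sketch of the event-driven scheduling handoff; nothing here is
# lightllm API, only the coordination pattern.
pending_batch = None
schedule_event = asyncio.Event()

async def decode_loop():
    # Stands in for _decode_batch: wake the scheduler once per decode step.
    for _ in range(3):
        schedule_event.set()
        await asyncio.sleep(0.01)  # stands in for the decode RPC

async def schedule_loop():
    # Stands in for the scheduling tail of loop_for_netio_req.
    global pending_batch
    for _ in range(3):
        try:
            await asyncio.wait_for(schedule_event.wait(), timeout=0.02)
        except asyncio.TimeoutError:
            pass
        if schedule_event.is_set():
            pending_batch = "new-batch"  # stands in for generate_new_batch()
            schedule_event.clear()

async def main():
    await asyncio.gather(decode_loop(), schedule_loop())

asyncio.run(main())
```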
@@ -140,8 +134,6 @@ async def wait_to_model_ready(self):
             self.model_rpc_servers.append(rpc_model)
 
         self.model_rpc_client = ModelRpcClient(
-            model_infer_servers=self.model_rpc_servers,
-            world_size=self.world_size,
             rpc_event=self.rpc_event,
             rpc_finished_event=self.rpc_finished_event,
         )
@@ -223,7 +215,6 @@ def add_req(self, group_req_indexes: GroupReqIndexes):
             logger.info(f"router recive req id {req.request_id} cost time {time.time() - req.start_time} s")
         self.req_queue.extend(req_group)
         self.send_to_detokenization.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL)
-
         return
 
     async def loop_for_fwd(
@@ -285,81 +276,39 @@ async def loop_for_fwd(
             if self.running_batch is None:
                 await asyncio.sleep(0.01)  # 10ms
 
-    async def get_schedule_result(self, running_batch: Batch):
-        if self.schedule_task is None:
-            _start_time = time.time()
-
-            def get_new_batch():
-                if time.time() - _start_time < 0.001:
-                    time.sleep(0.003)
-
-                limit_router_queue_length = None
-                if self.is_multinode_tp:
-                    # Use all_reduce to obtain the minimum value
-                    limit_router_queue_length = len(self.req_queue.waiting_req_list)
-                    limit_router_queue_length_tensor = torch.tensor(
-                        limit_router_queue_length, dtype=torch.int32, device="cpu"
-                    )
-                    dist.all_reduce(limit_router_queue_length_tensor, op=dist.ReduceOp.MIN, group=self.mulitnode_group)
-                    limit_router_queue_length = limit_router_queue_length_tensor.item()
-
-                new_batch = self.req_queue.generate_new_batch(running_batch, limit_router_queue_length)
-                return new_batch
-
-            self.schedule_task = asyncio.get_running_loop().run_in_executor(self.overlap_thread_pool, get_new_batch)
-            return None
-        else:
-            result = await self.schedule_task
-            self.schedule_task = None
-            return result
+    def generate_new_batch(self):
+        limit_router_queue_length = None
+        if self.is_multinode_tp:
+            # Use all_reduce to obtain the minimum value
+            limit_router_queue_length = len(self.req_queue.waiting_req_list)
+            limit_router_queue_length_tensor = torch.tensor(limit_router_queue_length, dtype=torch.int32, device="cpu")
+            dist.all_reduce(limit_router_queue_length_tensor, op=dist.ReduceOp.MIN, group=self.mulitnode_group)
+            limit_router_queue_length = limit_router_queue_length_tensor.item()
+
+        # Scheduling must account for the currently running batch as well as requests that were scheduled but have not started inference yet.
+        new_batch = self.req_queue.generate_new_batch(
+            Batch.merge_two_batch(self.running_batch, self.schedule_new_batch), limit_router_queue_length
+        )
+        self.schedule_new_batch = Batch.merge_two_batch(self.schedule_new_batch, new_batch)
+        return
 
     async def _step(self):
         """
        Event-handling loop
         """
-        # Remove all reqs that have already finished
-        # When there is no currently running request
-        if self.running_batch is None:
-            new_batch: Batch = await self.get_schedule_result(self.running_batch)
-            if new_batch is not None:
-                self.metric_client.histogram_observe("lightllm_batch_next_size", len(new_batch.reqs))
-                for req in new_batch.reqs:
-                    self.metric_client.histogram_observe(
-                        "lightllm_request_queue_duration_bucket", time.time() - req.start_time
-                    )
-                self.stats_tool.count_prompt_tokens(new_batch)
-                self.running_batch = new_batch
-                await self._prefill_batch(self.running_batch)
-                self._filter_runing_batch()
-
-                # Aggressive-scheduling control
-                if not self.args.disable_aggressive_schedule:
-                    self.has_wait_tokens = self.max_wait_tokens
-
-            elif self.is_multinode_and_multidp:
-                # In multi-node multi-dp mode, even when the current running_batch is None, decode still has to
-                # be called continuously, because dp replicas on other nodes may have running requests. The
-                # inference backend pads some fake requests so the inference step can complete normally. This
-                # mainly serves large models such as deepseekv3, whose ep parallel mode requires all nodes to cooperate.
-                await self._decode_batch(self.running_batch)
-
-            return
-
-        # There are running requests: trigger when the number of consecutive decode steps reaches a threshold, or when a result from the previous pre-scheduling exists.
-        if self.has_wait_tokens >= self.max_wait_tokens or self.schedule_task is not None:
-            new_mini_batch = await self.get_schedule_result(self.running_batch)
+        # Decide whether new requests should join inference:
+        # under aggressive scheduling, any newly scheduled batch joins immediately;
+        # otherwise it joins once the delayed step count reaches the threshold.
+        if (self.schedule_new_batch is not None) and (
+            (not self.args.disable_aggressive_schedule) or (self.has_wait_tokens >= self.max_wait_tokens)
+        ):
+            new_batch = self.schedule_new_batch
+            self.schedule_new_batch = None
+            self._add_new_batch_to_running_batch(new_batch=new_batch)
+            await self._prefill_batch(new_batch)
+            self.stats_tool.count_prompt_tokens(new_batch)
+            self._filter_reqs_from_running_batch()
             self.has_wait_tokens = 0
-            if new_mini_batch is not None:
-
-                # Aggressive-scheduling control
-                if not self.args.disable_aggressive_schedule:
-                    self.has_wait_tokens = self.max_wait_tokens
-
-                self.stats_tool.count_prompt_tokens(new_mini_batch)
-                await self._prefill_batch(new_mini_batch)
-                if not new_mini_batch.is_clear():
-                    self.running_batch.merge(new_mini_batch)
-            return
 
         # Check if need pause some requests for decode.
         for dp_index in range(self.dp_size_in_node):
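The multinode-TP branch of `generate_new_batch` MIN-reduces each node's waiting-queue length so every rank schedules against the same (smallest) value and stays in lockstep. A standalone sketch of that collective, assuming a two-process gloo group on localhost (the group handle `self.mulitnode_group` and lightllm's real rendezvous settings are not reproduced; the port is arbitrary):

```python
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank: int, world_size: int):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29511"  # hypothetical free port
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Pretend each node has a different number of waiting requests.
    queue_len = [7, 3][rank]
    t = torch.tensor(queue_len, dtype=torch.int32, device="cpu")
    # MIN-reduce so every rank sees the smallest queue length.
    dist.all_reduce(t, op=dist.ReduceOp.MIN)
    print(f"rank {rank}: local={queue_len} limit={t.item()}")  # limit=3 on both

    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)
```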
@@ -374,51 +323,46 @@ async def _step(self):
 
         # Decode
         self.stats_tool.count_output_tokens(self.running_batch)
-        await self._decode_batch(self.running_batch)
-        self._filter_runing_batch()
+        await self._decode_batch()
+        self._filter_reqs_from_running_batch()
         self.has_wait_tokens += 1
         return
 
     async def _prefill_batch(self, batch: Batch):
-        start_time = time.time()
-        self.metric_client.counter_inc("lightllm_batch_inference_count", "prefill")
+        # Add new requests
         reqs = [r.to_router_rpc_obj() for r in batch.reqs]
         await self.model_rpc_client.prefill(reqs)
-        batch.filter_out_finished_req(self.shm_req_manager)
         self._send_detokenization_pack()
-
         logger.debug(f"Prefill Batch: {batch.simple_log()} \n")
-        self.metric_client.histogram_observe(
-            "lightllm_batch_inference_duration_bucket", time.time() - start_time, "prefill"
-        )
         return
 
-    async def _decode_batch(self, batch: Batch):
-        start_time = time.time()
-        self.metric_client.counter_inc("lightllm_batch_inference_count", "decode")
+    async def _decode_batch(self):
+        self.schedule_event.set()
         await self.model_rpc_client.decode()
-        # When self.is_multinode_and_multidp is True, the batch object passed in may be None.
-        if batch is not None:
-            batch.filter_out_finished_req(self.shm_req_manager)
-
         self._send_detokenization_pack()
-        self.metric_client.histogram_observe(
-            "lightllm_batch_inference_duration_bucket", time.time() - start_time, "decode"
-        )
         return
 
-    async def _pause_reqs(self, pasue_reqs):
+    async def _pause_reqs(self, pasue_reqs: List[Req]):
         pasue_req_ids = [r.request_id for r in pasue_reqs]
         await self.model_rpc_client.pause_reqs(pasue_req_ids)
         return
 
-    def _filter_runing_batch(self):
-        if self.running_batch is not None and self.running_batch.is_clear():
-            self.running_batch = None
-        return
+    def _add_new_batch_to_running_batch(self, new_batch: Batch):
+        if self.running_batch is None:
+            self.running_batch = new_batch
+        else:
+            self.running_batch.merge(new_batch)
+        return
+
+    def _filter_reqs_from_running_batch(self):
+        if self.running_batch is not None:
+            self.running_batch.filter_out_finished_req(self.shm_req_manager)
+            if self.running_batch.is_clear():
+                self.running_batch = None
+        return
 
     def _can_decode(self, batch: Batch, dp_index: int):
-        if self.is_pd_run_mode or self.is_safe_schedule:
+        if self.is_pd_run_mode or self.is_safe_schedule or batch is None:
             return True
         return (
             batch.get_batch_decode_need_tokens()[dp_index] + self.get_used_tokens(dp_index) <= self.max_total_token_num
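`_can_decode` now also returns True for a `None` batch, matching the multinode-multidp path where decode runs without a local batch. The underlying budget test is plain arithmetic: a decode step is admitted only if the tokens it needs this step, plus the tokens already referenced, fit in the KV pool. A sketch with hypothetical numbers standing in for the mem-manager statistics:

```python
# Hedged sketch of the decode token-budget guard; values are made up.
max_total_token_num = 16384

def can_decode(decode_need_tokens: int, used_tokens: int) -> bool:
    # One decode step is safe only if the new tokens still fit in the KV pool.
    return decode_need_tokens + used_tokens <= max_total_token_num

print(can_decode(decode_need_tokens=128, used_tokens=16000))  # True
print(can_decode(decode_need_tokens=512, used_tokens=16000))  # False -> pause reqs
```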
@@ -443,12 +387,35 @@ def get_used_tokens(self, dp_index):
         return self.max_total_token_num - self.read_only_statics_mem_manager.get_unrefed_token_num(dp_index)
 
     async def loop_for_netio_req(self):
+        recv_max_count = 64
+
         while True:
-            recv_req: GroupReqIndexes = await self.recv_from_httpserver.recv_pyobj()
-            if isinstance(recv_req, GroupReqIndexes):
-                self.add_req(recv_req)
-            else:
-                assert False, f"Error Req Inf {recv_req}"
+            try:
+                # Take at most recv_max_count requests from zmq per iteration, so a flood of queued requests cannot block the main loop.
+                for _ in range(recv_max_count):
+                    recv_req: GroupReqIndexes = self.recv_from_httpserver.recv_pyobj(zmq.NOBLOCK)
+                    if isinstance(recv_req, GroupReqIndexes):
+                        self.add_req(recv_req)
+                    else:
+                        assert False, f"Error Req Inf {recv_req}"
+
+                # When many requests are queued, raise the per-iteration receive limit
+                recv_max_count = min(int(recv_max_count * 1.3), 256)
+
+            except zmq.ZMQError:
+                # Once the queue starts to drain, lower the per-iteration receive limit
+                recv_max_count = 64
+
+            try:
+                await asyncio.wait_for(self.schedule_event.wait(), timeout=0.02)
+            except asyncio.TimeoutError:
+                pass
+
+            if self.schedule_event.is_set():
+                self.generate_new_batch()
+                self.schedule_event.clear()
+
+        return
 
     def clean_up(self):
         return
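The reworked `loop_for_netio_req` pairs with the earlier switch from `zmq.asyncio.Context` to a plain `zmq.Context`: the PULL socket is drained in bounded non-blocking bursts, the burst size grows by 1.3x (capped at 256) while traffic stays heavy, and the `zmq.ZMQError` raised by `recv_pyobj(zmq.NOBLOCK)` on an empty queue resets it to 64. A self-contained sketch of that drain pattern over a hypothetical inproc socket pair:

```python
import zmq

# Hedged sketch of the bounded non-blocking drain; the inproc address and
# payloads are made up for illustration.
ctx = zmq.Context(2)
push = ctx.socket(zmq.PUSH)
push.bind("inproc://reqs")
pull = ctx.socket(zmq.PULL)
pull.connect("inproc://reqs")

for i in range(300):
    push.send_pyobj({"req_id": i})

recv_max_count = 64
drained = 0
while True:
    try:
        # Take at most recv_max_count messages per burst so one busy queue
        # cannot monopolize the loop.
        for _ in range(recv_max_count):
            pull.recv_pyobj(zmq.NOBLOCK)
            drained += 1
        # Queue still busy: raise the burst size, capped at 256.
        recv_max_count = min(int(recv_max_count * 1.3), 256)
    except zmq.ZMQError:
        break  # queue empty; the router instead resets to 64 and keeps looping

print(drained)  # 300
ctx.destroy()
```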
@@ -459,6 +426,13 @@ def start_router_process(args, router_port, detokenization_port, metric_port, pi
     graceful_registry(inspect.currentframe().f_code.co_name)
     start_parent_check_thread()
 
+    def handle_exception(loop, context):
+        logger.exception(f"Router Caught exception: {str(context)}")
+
+    loop = asyncio.new_event_loop()
+    loop.set_exception_handler(handle_exception)
+    asyncio.set_event_loop(loop)
+
     try:
         router = RouterManager(
             args,
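Moving the loop construction ahead of `RouterManager` startup means `wait_to_model_ready`, `loop_for_fwd`, and `loop_for_netio_req` all run on the same uvloop-backed loop with the exception handler installed from the start, instead of `asyncio.run` spinning up a throwaway loop for model init. A minimal sketch of the one-loop setup, with `init`/`serve` as hypothetical stand-ins for the router coroutines:

```python
import asyncio

def handle_exception(loop, context):
    # Called for exceptions that escape tasks on this loop.
    print(f"Caught exception: {context}")

loop = asyncio.new_event_loop()
loop.set_exception_handler(handle_exception)
asyncio.set_event_loop(loop)

async def init():
    await asyncio.sleep(0)  # stands in for wait_to_model_ready()

async def serve():
    for _ in range(3):
        await asyncio.sleep(0)  # stands in for the router loops

loop.run_until_complete(init())   # model init on the same loop
loop.create_task(serve())         # background task, like loop_for_fwd()
loop.run_until_complete(serve())  # foreground loop, like loop_for_netio_req()
```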
@@ -467,7 +441,7 @@ def start_router_process(args, router_port, detokenization_port, metric_port, pi
             metric_port=metric_port,
         )
 
-        asyncio.run(router.wait_to_model_ready())
+        loop.run_until_complete(router.wait_to_model_ready())
     except:
         import traceback
         import sys
@@ -480,13 +454,6 @@ def start_router_process(args, router_port, detokenization_port, metric_port, pi
         raise
 
     pipe_writer.send("init ok")
-
-    def handle_exception(loop, context):
-        logger.exception(f"Router Caught exception: {str(context)}")
-
-    loop = asyncio.new_event_loop()
-    loop.set_exception_handler(handle_exception)
-    asyncio.set_event_loop(loop)
     loop.create_task(router.loop_for_fwd())
     loop.run_until_complete(router.loop_for_netio_req())
     return