22from typing import List
33from ..batch import Batch , Req
44from lightllm .server .router .req_queue .base_queue import BaseQueue
5+ from lightllm .server .router .req_queue .dp_balancer import get_dp_balancer
56from lightllm .common .basemodel .infer_lock import g_router_lock
67from lightllm .utils .log_utils import init_logger
78
@@ -12,14 +13,18 @@ class DpQueue:
1213 def __init__ (self , args , router , base_queue_class , dp_size_in_node ) -> None :
1314 self .dp_size_in_node = dp_size_in_node
1415 self .base_queue_class = base_queue_class
15- self .pre_select_dp_index = self .dp_size_in_node - 1
1616 from lightllm .server .router .manager import RouterManager
1717
1818 self .router : RouterManager = router
1919 self .inner_queues : List [BaseQueue ] = [
2020 base_queue_class (args , router , dp_index , dp_size_in_node ) for dp_index in range (self .dp_size_in_node )
2121 ]
22-
22+ # 在调度这放松,在推理时约束。
23+ # 避免prefill 模式下的情况下,推理完成了,调度没及时获取信息,导致调度bs 过小
24+ for queue in self .inner_queues :
25+ queue .batch_max_tokens = int (args .batch_max_tokens * 2 )
26+ self .dp_balancer = get_dp_balancer (args , dp_size_in_node , self .inner_queues )
27+ self .reqs_waiting_for_dp_index : List [List [Req ]] = []
2328 return
2429
2530 def get_dp_queue (self , dp_index : int ):
@@ -31,6 +36,7 @@ def get_wait_req_num(self):
3136
3237 # @calculate_time(show=True, min_cost_ms=10)
3338 def generate_new_batch (self , current_batch : Batch ):
39+ self .dp_balancer .assign_reqs_to_dp (current_batch , self .reqs_waiting_for_dp_index )
3440 batches = [
3541 self .inner_queues [dp_index ].generate_new_batch (current_batch ) for dp_index in range (self .dp_size_in_node )
3642 ]
@@ -45,31 +51,13 @@ def _merge_batch(self, dp_batches: List[Batch]):
4551 merged_batch = iter_batch
4652 return merged_batch
4753
48- def append (self , req : Req ):
49- suggested_dp_index = req .sample_params .suggested_dp_index
54+ def extend (self , req_group : List [ Req ] ):
55+ suggested_dp_index = req_group [ 0 ] .sample_params .suggested_dp_index
5056 if suggested_dp_index >= self .dp_size_in_node or suggested_dp_index < 0 :
51- logger .warning (f"input req { req .request_id } dp index { suggested_dp_index } is invalid" )
52- suggested_dp_index = self ._get_suggest_dp_index ()
53- self .pre_select_dp_index = suggested_dp_index
54- req .sample_params .suggested_dp_index = suggested_dp_index
55- self .inner_queues [suggested_dp_index ].append (req )
57+ # 同一个组的,要分配在同一个 dp 上
58+ self .reqs_waiting_for_dp_index .append (req_group )
5659 else :
57- self .inner_queues [suggested_dp_index ].append (req )
58- return
59-
60- def extend (self , req_group : List [Req ]):
61- # 同一个组的,要分配在同一个 dp 上,效率最高
62- index = self ._get_suggest_dp_index ()
63- for req in req_group :
64- suggested_dp_index = req .sample_params .suggested_dp_index
65- if suggested_dp_index >= self .dp_size_in_node or suggested_dp_index < 0 :
66- logger .warning (f"input req { req .request_id } dp index { suggested_dp_index } is invalid" )
67- self .pre_select_dp_index = index
68- req .sample_params .suggested_dp_index = index
69- self .inner_queues [index ].append (req )
70- else :
71- self .inner_queues [suggested_dp_index ].append (req )
72-
60+ self .inner_queues [suggested_dp_index ].extend (req_group )
7361 return
7462
7563 def is_busy (self ):
@@ -87,21 +75,3 @@ def update_token_load(self, current_batch: Batch, force_update=False):
8775 self .router .shared_token_load .set_estimated_peak_token_count (estimated_peak_token_count , dp_index )
8876 self .router .shared_token_load .set_dynamic_max_load (dynamic_max_load , dp_index )
8977 return
90-
91- def _get_suggest_dp_index (self ):
92- min_length = min (len (queue .waiting_req_list ) for queue in self .inner_queues )
93- select_dp_indexes = [
94- i for i , queue in enumerate (self .inner_queues ) if len (queue .waiting_req_list ) == min_length
95- ]
96-
97- # multi thread safe keep
98- if not select_dp_indexes :
99- return random .randint (0 , self .dp_size_in_node - 1 )
100-
101- # round_robin select.
102- for i in range (self .dp_size_in_node ):
103- next_dp_index = (self .pre_select_dp_index + i + 1 ) % self .dp_size_in_node
104- if next_dp_index in select_dp_indexes :
105- return next_dp_index
106-
107- return random .choice (select_dp_indexes )
0 commit comments