Commit f4193c6

Dp balancer (#991)

1 parent 5ea50a9 commit f4193c6

File tree

11 files changed: +229 −116 lines changed

lightllm/server/api_cli.py

Lines changed: 7 additions & 0 deletions

@@ -144,6 +144,13 @@ def make_argument_parser() -> argparse.ArgumentParser:
         using the deepseekv2 model, set dp to be equal to the tp parameter. In other cases, please
         do not set it and keep the default value as 1.""",
     )
+    parser.add_argument(
+        "--dp_balancer",
+        type=str,
+        default="bs_balancer",
+        choices=["round_robin", "bs_balancer"],
+        help="the dp balancer type, default is bs_balancer",
+    )
     parser.add_argument(
         "--max_req_total_len", type=int, default=16384, help="the max value for req_input_len + req_output_len"
    )
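For reference, the new option is passed like any other server flag; a hypothetical launch line (the model path and the other flags here are placeholders, not part of this commit):

python -m lightllm.server.api_server --model_dir /path/to/model --tp 8 --dp 8 --dp_balancer round_robin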

lightllm/server/router/batch.py

Lines changed: 9 additions & 0 deletions

@@ -40,6 +40,15 @@ def get_req_list_for_dp(self, dp_index: int):
             req_list.append(req)
         return req_list

+    def get_all_dp_req_num(self) -> List[int]:
+        if self.dp_size_in_node == 1:
+            return [len(self.reqs)]
+
+        all_dp_req_num = [0 for _ in range(self.dp_size_in_node)]
+        for req in self.reqs:
+            all_dp_req_num[req.sample_params.suggested_dp_index] += 1
+        return all_dp_req_num
+
     def filter_out_finished_req(self, shm_req_manager: ShmReqManager):
         unfinished_req_ids = []
         for req in self.reqs:
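As an illustration (hypothetical state): for a batch with dp_size_in_node = 4 whose requests carry suggested_dp_index values 0, 0, and 2, the new helper returns the per-rank running request counts used by the balancer:

batch.get_all_dp_req_num()  # -> [2, 0, 1, 0]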

lightllm/server/router/manager.py

Lines changed: 0 additions & 4 deletions

@@ -197,10 +197,6 @@ async def wait_to_model_ready(self):
         return

     def _get_schedule_time_interval(self):
-        if self.running_batch is None:
-            # when there is no running batch, trigger request scheduling every 10 ms
-            return 0.01
-
         # in dp mode, a longer scheduling interval is needed for better balancing, so that more requests can be collected
         return self.schedule_time_interval
lightllm/server/router/model_infer/mode_backend/base_backend.py

Lines changed: 5 additions & 0 deletions

@@ -75,6 +75,7 @@ def init_model(self, kvargs):
         self.chunked_prefill_size = self.args.chunked_prefill_size
         self.return_all_prompt_logprobs = self.args.return_all_prompt_logprobs
         self.use_dynamic_prompt_cache = not self.args.disable_dynamic_prompt_cache
+        self.batch_max_tokens = self.args.batch_max_tokens
         self.eos_id: List[int] = kvargs.get("eos_id", [2])
         self.disable_cudagraph = self.args.disable_cudagraph
         self.is_multinode_tp = self.args.nnodes > 1 and self.args.dp == 1

@@ -395,6 +396,7 @@ def _get_classed_reqs(
         # requests; that logic is not suitable here.
         pause_max_req_num = 2
         wait_pause_count = 0
+        prefill_tokens = 0

         # the radix cache and mem_manager counters are used below,
         # so lock protection is required.

@@ -443,7 +445,10 @@ def _get_classed_reqs(
                 wait_pause_count += 1
             else:
                 token_num = req_obj.prefill_need_token_num(is_chuncked_prefill=not self.disable_chunked_prefill)
+                if prefill_tokens + token_num > self.batch_max_tokens:
+                    continue
                 if token_num <= can_alloc_token_num:
+                    prefill_tokens += token_num
                     prefill_reqs.append(req_obj)
                     can_alloc_token_num -= token_num
                 else:
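The inserted lines cap the total prefill tokens admitted in one scheduling pass at batch_max_tokens. A minimal sketch of that budget check in isolation (made-up numbers; the real loop also handles pausing and allocation limits):

batch_max_tokens = 8192
prefill_tokens = 0
admitted = []
for token_num in [4096, 3000, 2048, 1024]:  # prefill_need_token_num per candidate request
    if prefill_tokens + token_num > batch_max_tokens:
        continue  # 2048 is skipped: 4096 + 3000 + 2048 exceeds the budget
    prefill_tokens += token_num
    admitted.append(token_num)
print(admitted, prefill_tokens)  # [4096, 3000, 1024] 8120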

lightllm/server/router/req_queue/base_queue.py

Lines changed: 0 additions & 5 deletions

@@ -26,11 +26,6 @@ def __init__(self, args, router, dp_index, dp_size_in_node) -> None:
         self.router_token_ratio = args.router_token_ratio  # ratio to determine whether the router is busy
         self.router_max_new_token_len = args.router_max_new_token_len

-    def append(self, req: Req):
-        req.sample_params.suggested_dp_index = self.dp_index
-        self.waiting_req_list.append(req)
-        return
-
     def extend(self, req_group: List[Req]):
         for req in req_group:
             req.sample_params.suggested_dp_index = self.dp_index
lightllm/server/router/req_queue/dp_balancer/__init__.py

Lines changed: 13 additions & 0 deletions

@@ -0,0 +1,13 @@
+from .roundrobin import RoundRobinDpBalancer
+from typing import List
+from lightllm.server.router.req_queue.base_queue import BaseQueue
+from .bs import DpBsBalancer
+
+
+def get_dp_balancer(args, dp_size_in_node: int, inner_queues: List[BaseQueue]):
+    if args.dp_balancer == "round_robin":
+        return RoundRobinDpBalancer(dp_size_in_node, inner_queues)
+    elif args.dp_balancer == "bs_balancer":
+        return DpBsBalancer(dp_size_in_node, inner_queues)
+    else:
+        raise ValueError(f"Invalid dp balancer: {args.dp_balancer}")
lightllm/server/router/req_queue/dp_balancer/base.py

Lines changed: 23 additions & 0 deletions

@@ -0,0 +1,23 @@
+import random
+from abc import ABC, abstractmethod
+from typing import List, Union
+from lightllm.server.router.req_queue.base_queue import BaseQueue
+from lightllm.server.router.batch import Batch, Req
+from lightllm.utils.log_utils import init_logger
+
+logger = init_logger(__name__)
+
+
+class DpBalancer(ABC):
+    """
+    Base class for DP load balancers.
+    It defines the interface for load-balancing strategies; subclasses can implement different balancing algorithms.
+    """
+
+    def __init__(self, dp_size_in_node: int, inner_queues: List[BaseQueue]):
+        self.dp_size_in_node = dp_size_in_node
+        self.inner_queues = inner_queues
+
+    @abstractmethod
+    def assign_reqs_to_dp(self, current_batch: Batch, reqs_waiting_for_dp_index: List[List[Req]]) -> None:
+        pass
lightllm/server/router/req_queue/dp_balancer/bs.py

Lines changed: 45 additions & 0 deletions

@@ -0,0 +1,45 @@
+import random
+from typing import List, Union
+from lightllm.server.router.req_queue.base_queue import BaseQueue
+from lightllm.server.router.batch import Batch, Req
+from lightllm.utils.log_utils import init_logger
+from .base import DpBalancer
+
+logger = init_logger(__name__)
+
+
+class DpBsBalancer(DpBalancer):
+    """
+    This balancer mainly balances the batch size across dp ranks.
+    In dp mode, if any dp rank is left without a request, a padding
+    request is created for it, which wastes GPU compute resources.
+    """
+
+    def __init__(self, dp_size_in_node: int, inner_queues: List[BaseQueue]):
+        super().__init__(dp_size_in_node, inner_queues)
+
+    def assign_reqs_to_dp(self, current_batch: Batch, reqs_waiting_for_dp_index: List[List[Req]]) -> None:
+        if len(reqs_waiting_for_dp_index) == 0:
+            return
+        # calculate the total load of each dp rank
+        all_dp_req_num = [0 for _ in range(self.dp_size_in_node)]
+        if current_batch is not None:
+            all_dp_req_num = current_batch.get_all_dp_req_num()
+        total_load_per_dp = [
+            all_dp_req_num[i] + len(self.inner_queues[i].waiting_req_list) for i in range(self.dp_size_in_node)
+        ]
+        for req_group in reqs_waiting_for_dp_index:
+            # find the dp ranks with the minimum load
+            min_load = min(total_load_per_dp)
+            select_dp_indexes = [i for i in range(self.dp_size_in_node) if total_load_per_dp[i] == min_load]
+            suggested_dp_index = random.choice(select_dp_indexes)
+
+            # assign the request group to that dp rank and update the load count
+            for req in req_group:
+                req.sample_params.suggested_dp_index = suggested_dp_index
+            self.inner_queues[suggested_dp_index].extend(req_group)
+            # update the load count for this dp rank
+            total_load_per_dp[suggested_dp_index] += len(req_group)
+
+        reqs_waiting_for_dp_index.clear()
+        return
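To make the selection step concrete, a small worked example with made-up numbers:

# hypothetical state: 4 dp ranks
all_dp_req_num = [3, 1, 1, 2]  # running requests per rank, from current_batch
waiting_lens = [0, 2, 0, 1]    # len(queue.waiting_req_list) per rank
total_load_per_dp = [r + w for r, w in zip(all_dp_req_num, waiting_lens)]  # [3, 3, 1, 3]
min_load = min(total_load_per_dp)  # 1
select_dp_indexes = [i for i, v in enumerate(total_load_per_dp) if v == min_load]  # [2]
# the whole req_group is pinned to rank 2, then total_load_per_dp[2] += len(req_group)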
lightllm/server/router/req_queue/dp_balancer/roundrobin.py

Lines changed: 51 additions & 0 deletions

@@ -0,0 +1,51 @@
+import random
+from typing import List, Union
+from lightllm.server.router.req_queue.base_queue import BaseQueue
+from lightllm.server.router.batch import Batch, Req
+from lightllm.utils.log_utils import init_logger
+from .base import DpBalancer
+
+logger = init_logger(__name__)
+
+
+class RoundRobinDpBalancer(DpBalancer):
+    """
+    Round-robin load balancer.
+    It round-robins among the dp ranks whose waiting queues are shortest.
+    """
+
+    def __init__(self, dp_size_in_node: int, inner_queues: List[BaseQueue]):
+        super().__init__(dp_size_in_node, inner_queues)
+        self.pre_select_dp_index = self.dp_size_in_node - 1
+
+    def get_suggest_dp_index(
+        self,
+    ) -> int:
+        min_length = min(len(queue.waiting_req_list) for queue in self.inner_queues)
+        select_dp_indexes = [
+            i for i, queue in enumerate(self.inner_queues) if len(queue.waiting_req_list) == min_length
+        ]
+
+        # if there is no candidate index, pick one at random
+        if not select_dp_indexes:
+            self.pre_select_dp_index = random.randint(0, self.dp_size_in_node - 1)
+            return self.pre_select_dp_index
+
+        # round-robin selection
+        for i in range(self.dp_size_in_node):
+            next_dp_index = (self.pre_select_dp_index + i + 1) % self.dp_size_in_node
+            if next_dp_index in select_dp_indexes:
+                self.pre_select_dp_index = next_dp_index
+                return self.pre_select_dp_index
+
+        self.pre_select_dp_index = random.choice(select_dp_indexes)
+        return self.pre_select_dp_index
+
+    def assign_reqs_to_dp(self, current_batch: Batch, reqs_waiting_for_dp_index: List[List[Req]]) -> None:
+        for req_group in reqs_waiting_for_dp_index:
+            suggested_dp_index = self.get_suggest_dp_index()
+            for req in req_group:
+                req.sample_params.suggested_dp_index = suggested_dp_index
+            self.inner_queues[suggested_dp_index].extend(req_group)
+        reqs_waiting_for_dp_index.clear()
+        return
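A worked example of the round-robin pick (hypothetical state): with dp_size_in_node = 4, pre_select_dp_index = 1, and waiting-queue lengths [2, 0, 2, 0], the minimum length is 0, so the candidates are ranks 1 and 3. Scanning forward from rank 2 reaches rank 3 first, so rank 3 is selected and becomes the new pre_select_dp_index, keeping successive groups spread across the least-loaded ranks.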

lightllm/server/router/req_queue/dp_base_queue.py

Lines changed: 13 additions & 43 deletions

@@ -2,6 +2,7 @@
 from typing import List
 from ..batch import Batch, Req
 from lightllm.server.router.req_queue.base_queue import BaseQueue
+from lightllm.server.router.req_queue.dp_balancer import get_dp_balancer
 from lightllm.common.basemodel.infer_lock import g_router_lock
 from lightllm.utils.log_utils import init_logger

@@ -12,14 +13,18 @@ class DpQueue:
     def __init__(self, args, router, base_queue_class, dp_size_in_node) -> None:
         self.dp_size_in_node = dp_size_in_node
         self.base_queue_class = base_queue_class
-        self.pre_select_dp_index = self.dp_size_in_node - 1
         from lightllm.server.router.manager import RouterManager

         self.router: RouterManager = router
         self.inner_queues: List[BaseQueue] = [
             base_queue_class(args, router, dp_index, dp_size_in_node) for dp_index in range(self.dp_size_in_node)
         ]
-
+        # relax the limit at scheduling time and enforce it at inference time.
+        # this avoids the case in prefill mode where inference has finished but the scheduler
+        # has not yet received that information, leading to a scheduled batch size that is too small
+        for queue in self.inner_queues:
+            queue.batch_max_tokens = int(args.batch_max_tokens * 2)
+        self.dp_balancer = get_dp_balancer(args, dp_size_in_node, self.inner_queues)
+        self.reqs_waiting_for_dp_index: List[List[Req]] = []
         return

     def get_dp_queue(self, dp_index: int):

@@ -31,6 +36,7 @@ def get_wait_req_num(self):

     # @calculate_time(show=True, min_cost_ms=10)
     def generate_new_batch(self, current_batch: Batch):
+        self.dp_balancer.assign_reqs_to_dp(current_batch, self.reqs_waiting_for_dp_index)
         batches = [
             self.inner_queues[dp_index].generate_new_batch(current_batch) for dp_index in range(self.dp_size_in_node)
         ]

@@ -45,31 +51,13 @@ def _merge_batch(self, dp_batches: List[Batch]):
             merged_batch = iter_batch
         return merged_batch

-    def append(self, req: Req):
-        suggested_dp_index = req.sample_params.suggested_dp_index
+    def extend(self, req_group: List[Req]):
+        suggested_dp_index = req_group[0].sample_params.suggested_dp_index
         if suggested_dp_index >= self.dp_size_in_node or suggested_dp_index < 0:
-            logger.warning(f"input req {req.request_id} dp index {suggested_dp_index} is invalid")
-            suggested_dp_index = self._get_suggest_dp_index()
-            self.pre_select_dp_index = suggested_dp_index
-            req.sample_params.suggested_dp_index = suggested_dp_index
-            self.inner_queues[suggested_dp_index].append(req)
+            # requests in the same group must be assigned to the same dp rank
+            self.reqs_waiting_for_dp_index.append(req_group)
         else:
-            self.inner_queues[suggested_dp_index].append(req)
-        return
-
-    def extend(self, req_group: List[Req]):
-        # requests in the same group must go to the same dp rank; that is most efficient
-        index = self._get_suggest_dp_index()
-        for req in req_group:
-            suggested_dp_index = req.sample_params.suggested_dp_index
-            if suggested_dp_index >= self.dp_size_in_node or suggested_dp_index < 0:
-                logger.warning(f"input req {req.request_id} dp index {suggested_dp_index} is invalid")
-                self.pre_select_dp_index = index
-                req.sample_params.suggested_dp_index = index
-                self.inner_queues[index].append(req)
-            else:
-                self.inner_queues[suggested_dp_index].append(req)
-
+            self.inner_queues[suggested_dp_index].extend(req_group)
         return

     def is_busy(self):

@@ -87,21 +75,3 @@ def update_token_load(self, current_batch: Batch, force_update=False):
             self.router.shared_token_load.set_estimated_peak_token_count(estimated_peak_token_count, dp_index)
             self.router.shared_token_load.set_dynamic_max_load(dynamic_max_load, dp_index)
         return
-
-    def _get_suggest_dp_index(self):
-        min_length = min(len(queue.waiting_req_list) for queue in self.inner_queues)
-        select_dp_indexes = [
-            i for i, queue in enumerate(self.inner_queues) if len(queue.waiting_req_list) == min_length
-        ]
-
-        # multi thread safe keep
-        if not select_dp_indexes:
-            return random.randint(0, self.dp_size_in_node - 1)
-
-        # round_robin select.
-        for i in range(self.dp_size_in_node):
-            next_dp_index = (self.pre_select_dp_index + i + 1) % self.dp_size_in_node
-            if next_dp_index in select_dp_indexes:
-                return next_dp_index
-
-        return random.choice(select_dp_indexes)
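Net effect of this refactor: DpQueue no longer picks a dp rank itself. Request groups arriving without a valid suggested_dp_index are parked in reqs_waiting_for_dp_index, and the configured balancer assigns each group to a single rank at the start of the next generate_new_batch call, when per-rank load information is freshest; groups that already carry a valid index bypass the balancer entirely.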
