Commit 2a538d4

router recv reqs update (#955)
Co-authored-by: baishihao <baishihao@sensetime.com>
Co-authored-by: wangzaijun <wzjhelloworld@qq.com>
Co-authored-by: hiworldwzj <30762946+hiworldwzj@users.noreply.github.com>
1 parent 3eacc13 commit 2a538d4

File tree

11 files changed: +160 additions, -217 deletions

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
@@ -11,4 +11,4 @@ repos:
     hooks:
       - id: flake8
         additional_dependencies: [flake8-typing-imports==1.9.0]
-        args: ['--config=.flake8', '--max-line-length=120', '--ignore=TYP001, E722, C901, E203, E266, E402, E302, E241, E902, E731, F403, E701, F405, F401, W292, W293, W503, W606']
+        args: ['--config=.flake8', '--max-line-length=120', '--ignore=TYP001, E722, C901, E203, E266, E402, E302, E241, E902, E731, F403, E701, F405, F401, W292, W293, W503, W606, E231']
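For reference, E231 is pycodestyle's "missing whitespace after ',', ';', or ':'" check, which this change adds to the flake8 ignore list. A small hypothetical snippet (not from this repo) showing what the check would otherwise flag:

```python
# Both lines below trigger E231 when the check is enabled:
coords = (1,2)            # E231: missing whitespace after ','
def scale(x:int) -> int:  # E231: missing whitespace after ':'
    return x * 2
```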

lightllm/server/router/batch.py

Lines changed: 12 additions & 6 deletions
@@ -67,12 +67,6 @@ def is_clear(self):
         return len(self.reqs) == 0
 
     def merge(self, mini_batch: "Batch"):
-        for _req in mini_batch.reqs:
-            self.reqs.append(_req)
-        self.id_to_reqs = {req.request_id: req for req in self.reqs}
-        return
-
-    def dp_merge(self, mini_batch: "Batch"):
         if mini_batch is None:
             return
 
@@ -81,6 +75,18 @@ def dp_merge(self, mini_batch: "Batch"):
         self.id_to_reqs = {req.request_id: req for req in self.reqs}
         return
 
+    @staticmethod
+    def merge_two_batch(batch1: "Batch", batch2: "Batch") -> "Batch":
+        if batch1 is None and batch2 is None:
+            return None
+
+        not_none_batch = batch1 if batch1 is not None else batch2
+
+        merge_batch = Batch(-1, [], not_none_batch.dp_size_in_node)
+        merge_batch.merge(batch1)
+        merge_batch.merge(batch2)
+        return merge_batch
+
     def __repr__(self):
         return f"batch_id={self.batch_id}, " f"reqs={self.reqs}, "
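The new `merge_two_batch` helper composes with the reworked `merge` above: since `merge` now returns early on a `None` argument, either input may be missing, and only the all-`None` case yields `None`. A standalone sketch of these semantics (simplified stub, not the real lightllm `Batch` class):

```python
from typing import List, Optional

class Batch:
    def __init__(self, batch_id: int, reqs: List, dp_size_in_node: int):
        self.batch_id = batch_id
        self.reqs = reqs
        self.dp_size_in_node = dp_size_in_node

    def merge(self, mini_batch: Optional["Batch"]):
        # None is treated as an empty batch, mirroring the diff above
        if mini_batch is None:
            return
        self.reqs.extend(mini_batch.reqs)

    @staticmethod
    def merge_two_batch(batch1: Optional["Batch"], batch2: Optional["Batch"]) -> Optional["Batch"]:
        if batch1 is None and batch2 is None:
            return None
        not_none_batch = batch1 if batch1 is not None else batch2
        # batch_id -1 marks a synthetic container batch
        merged = Batch(-1, [], not_none_batch.dp_size_in_node)
        merged.merge(batch1)
        merged.merge(batch2)
        return merged

running = Batch(0, ["req_a"], dp_size_in_node=1)
assert Batch.merge_two_batch(running, None).reqs == ["req_a"]
assert Batch.merge_two_batch(None, None) is None
```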

lightllm/server/router/manager.py

Lines changed: 89 additions & 122 deletions
@@ -1,26 +1,20 @@
-import copy
 import time
-import uuid
 import uvloop
 import asyncio
 import torch
-import rpyc
 import pickle
-import threading
 import inspect
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
-import concurrent.futures
 import zmq
 import zmq.asyncio
 import torch.multiprocessing as mp
 import torch.distributed as dist
 import multiprocessing
 from typing import Dict, List, Optional
-from .batch import Batch
+from .batch import Batch, Req
 from .model_infer.model_rpc import start_model_process, ModelRpcClient
 from .req_queue import build_req_queue
-from lightllm.utils.infer_utils import calculate_time
 from lightllm.server.core.objs.io_objs import GroupReqIndexes
 from lightllm.server.core.objs import ShmReqManager, StartArgs
 from .dynamic_prompt.radix_cache import RadixCacheReadOnlyClient
@@ -79,7 +73,7 @@ def __init__(self, args: StartArgs, router_port, detokenization_port, metric_port
         self.eos_id = args.eos_id
         self.has_wait_tokens = 0
         self.max_wait_tokens = args.router_max_wait_tokens
-        context = zmq.asyncio.Context(2)
+        context = zmq.Context(2)
         self.recv_from_httpserver = context.socket(zmq.PULL)
         self.recv_from_httpserver.bind(f"{args.zmq_mode}127.0.0.1:{router_port}")
 
@@ -106,13 +100,13 @@ def __init__(self, args: StartArgs, router_port, detokenization_port, metric_port
         # mainly to keep scheduling mistakes from causing OOM and similar errors
         self.router_lock = mp.Lock()
         g_router_lock.obj = self.router_lock
-
-        # thread pool used to overlap scheduling with inference
-        self.overlap_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
-        self.schedule_task = None
         return
 
     async def wait_to_model_ready(self):
+        # objects used for scheduling
+        self.schedule_new_batch: Batch = None
+        self.schedule_event = asyncio.Event()
+
         # initialize the model
         self.model_rpc_servers = []
         # used to exchange task info between the kv-move manager process and the inference processes.
@@ -140,8 +134,6 @@ async def wait_to_model_ready(self):
             self.model_rpc_servers.append(rpc_model)
 
         self.model_rpc_client = ModelRpcClient(
-            model_infer_servers=self.model_rpc_servers,
-            world_size=self.world_size,
             rpc_event=self.rpc_event,
             rpc_finished_event=self.rpc_finished_event,
         )
@@ -223,7 +215,6 @@ def add_req(self, group_req_indexes: GroupReqIndexes):
             logger.info(f"router recive req id {req.request_id} cost time {time.time() - req.start_time} s")
         self.req_queue.extend(req_group)
         self.send_to_detokenization.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL)
-
         return
 
     async def loop_for_fwd(
@@ -285,81 +276,39 @@ async def loop_for_fwd(
             if self.running_batch is None:
                 await asyncio.sleep(0.01)  # 10ms
 
-    async def get_schedule_result(self, running_batch: Batch):
-        if self.schedule_task is None:
-            _start_time = time.time()
-
-            def get_new_batch():
-                if time.time() - _start_time < 0.001:
-                    time.sleep(0.003)
-
-                limit_router_queue_length = None
-                if self.is_multinode_tp:
-                    # use all_reduce to take the minimum
-                    limit_router_queue_length = len(self.req_queue.waiting_req_list)
-                    limit_router_queue_length_tensor = torch.tensor(
-                        limit_router_queue_length, dtype=torch.int32, device="cpu"
-                    )
-                    dist.all_reduce(limit_router_queue_length_tensor, op=dist.ReduceOp.MIN, group=self.mulitnode_group)
-                    limit_router_queue_length = limit_router_queue_length_tensor.item()
-
-                new_batch = self.req_queue.generate_new_batch(running_batch, limit_router_queue_length)
-                return new_batch
-
-            self.schedule_task = asyncio.get_running_loop().run_in_executor(self.overlap_thread_pool, get_new_batch)
-            return None
-        else:
-            result = await self.schedule_task
-            self.schedule_task = None
-            return result
+    def generate_new_batch(self):
+        limit_router_queue_length = None
+        if self.is_multinode_tp:
+            # use all_reduce to take the minimum
+            limit_router_queue_length = len(self.req_queue.waiting_req_list)
+            limit_router_queue_length_tensor = torch.tensor(limit_router_queue_length, dtype=torch.int32, device="cpu")
+            dist.all_reduce(limit_router_queue_length_tensor, op=dist.ReduceOp.MIN, group=self.mulitnode_group)
+            limit_router_queue_length = limit_router_queue_length_tensor.item()
+
+        # scheduling must account for the currently running batch as well as requests
+        # that were already scheduled but have not yet entered inference.
+        new_batch = self.req_queue.generate_new_batch(
+            Batch.merge_two_batch(self.running_batch, self.schedule_new_batch), limit_router_queue_length
+        )
+        self.schedule_new_batch = Batch.merge_two_batch(self.schedule_new_batch, new_batch)
+        return
 
     async def _step(self):
         """
         event handling loop
         """
-        # remove all reqs that have already finished
-        # when there is currently no running request
-        if self.running_batch is None:
-            new_batch: Batch = await self.get_schedule_result(self.running_batch)
-            if new_batch is not None:
-                self.metric_client.histogram_observe("lightllm_batch_next_size", len(new_batch.reqs))
-                for req in new_batch.reqs:
-                    self.metric_client.histogram_observe(
-                        "lightllm_request_queue_duration_bucket", time.time() - req.start_time
-                    )
-                self.stats_tool.count_prompt_tokens(new_batch)
-                self.running_batch = new_batch
-                await self._prefill_batch(self.running_batch)
-                self._filter_runing_batch()
-
-                # aggressive schedule control
-                if not self.args.disable_aggressive_schedule:
-                    self.has_wait_tokens = self.max_wait_tokens
-
-            elif self.is_multinode_and_multidp:
-                # In multinode multi-dp mode, decode must still be called continuously when the
-                # local running_batch is None, since dp ranks on other nodes may have running
-                # requests; the inference backend pads fake requests so the step completes normally.
-                # Mainly for large models like deepseekv3, whose ep mode needs all nodes to cooperate.
-                await self._decode_batch(self.running_batch)
-
-            return
-
-        # There are running requests: schedule again once decode has run for a threshold
-        # number of steps, or when a previously prescheduled result exists.
-        if self.has_wait_tokens >= self.max_wait_tokens or self.schedule_task is not None:
-            new_mini_batch = await self.get_schedule_result(self.running_batch)
+        # Decide whether new requests join inference:
+        # under aggressive scheduling, any newly scheduled batch is added immediately;
+        # otherwise it is added once the step-delay threshold is met.
+        if (self.schedule_new_batch is not None) and (
+            (not self.args.disable_aggressive_schedule) or (self.has_wait_tokens >= self.max_wait_tokens)
+        ):
+            new_batch = self.schedule_new_batch
+            self.schedule_new_batch = None
+            self._add_new_batch_to_running_batch(new_batch=new_batch)
+            await self._prefill_batch(new_batch)
+            self.stats_tool.count_prompt_tokens(new_batch)
+            self._filter_reqs_from_running_batch()
             self.has_wait_tokens = 0
-            if new_mini_batch is not None:
-
-                # aggressive schedule control
-                if not self.args.disable_aggressive_schedule:
-                    self.has_wait_tokens = self.max_wait_tokens
-
-                self.stats_tool.count_prompt_tokens(new_mini_batch)
-                await self._prefill_batch(new_mini_batch)
-                if not new_mini_batch.is_clear():
-                    self.running_batch.merge(new_mini_batch)
-                return
 
         # Check if need pause some requests for decode.
         for dp_index in range(self.dp_size_in_node):
@@ -374,51 +323,46 @@ async def _step(self):
 
         # Decode
         self.stats_tool.count_output_tokens(self.running_batch)
-        await self._decode_batch(self.running_batch)
-        self._filter_runing_batch()
+        await self._decode_batch()
+        self._filter_reqs_from_running_batch()
         self.has_wait_tokens += 1
         return
 
     async def _prefill_batch(self, batch: Batch):
-        start_time = time.time()
-        self.metric_client.counter_inc("lightllm_batch_inference_count", "prefill")
+        # add new requests
         reqs = [r.to_router_rpc_obj() for r in batch.reqs]
         await self.model_rpc_client.prefill(reqs)
-        batch.filter_out_finished_req(self.shm_req_manager)
         self._send_detokenization_pack()
-
         logger.debug(f"Prefill Batch: {batch.simple_log()} \n")
-        self.metric_client.histogram_observe(
-            "lightllm_batch_inference_duration_bucket", time.time() - start_time, "prefill"
-        )
         return
 
-    async def _decode_batch(self, batch: Batch):
-        start_time = time.time()
-        self.metric_client.counter_inc("lightllm_batch_inference_count", "decode")
+    async def _decode_batch(self):
+        self.schedule_event.set()
         await self.model_rpc_client.decode()
-        # when self.is_multinode_and_multidp is True, the batch passed in may be None.
-        if batch is not None:
-            batch.filter_out_finished_req(self.shm_req_manager)
-
         self._send_detokenization_pack()
-        self.metric_client.histogram_observe(
-            "lightllm_batch_inference_duration_bucket", time.time() - start_time, "decode"
-        )
         return
 
-    async def _pause_reqs(self, pasue_reqs):
+    async def _pause_reqs(self, pasue_reqs: List[Req]):
         pasue_req_ids = [r.request_id for r in pasue_reqs]
         await self.model_rpc_client.pause_reqs(pasue_req_ids)
         return
 
-    def _filter_runing_batch(self):
-        if self.running_batch is not None and self.running_batch.is_clear():
-            self.running_batch = None
-        return
+    def _add_new_batch_to_running_batch(self, new_batch: Batch):
+        if self.running_batch is None:
+            self.running_batch = new_batch
+        else:
+            self.running_batch.merge(new_batch)
+        return
+
+    def _filter_reqs_from_running_batch(self):
+        if self.running_batch is not None:
+            self.running_batch.filter_out_finished_req(self.shm_req_manager)
+            if self.running_batch.is_clear():
+                self.running_batch = None
+        return
 
     def _can_decode(self, batch: Batch, dp_index: int):
-        if self.is_pd_run_mode or self.is_safe_schedule:
+        if self.is_pd_run_mode or self.is_safe_schedule or batch is None:
             return True
         return (
             batch.get_batch_decode_need_tokens()[dp_index] + self.get_used_tokens(dp_index) <= self.max_total_token_num
@@ -443,12 +387,35 @@ def get_used_tokens(self, dp_index):
         return self.max_total_token_num - self.read_only_statics_mem_manager.get_unrefed_token_num(dp_index)
 
     async def loop_for_netio_req(self):
+        recv_max_count = 64
+
         while True:
-            recv_req: GroupReqIndexes = await self.recv_from_httpserver.recv_pyobj()
-            if isinstance(recv_req, GroupReqIndexes):
-                self.add_req(recv_req)
-            else:
-                assert False, f"Error Req Inf {recv_req}"
+            try:
+                # pull at most recv_max_count requests from zmq per pass, so an over-full zmq queue cannot block the main loop.
+                for _ in range(recv_max_count):
+                    recv_req: GroupReqIndexes = self.recv_from_httpserver.recv_pyobj(zmq.NOBLOCK)
+                    if isinstance(recv_req, GroupReqIndexes):
+                        self.add_req(recv_req)
+                    else:
+                        assert False, f"Error Req Inf {recv_req}"
+
+                # when many requests are queued, raise the per-pass receive budget
+                recv_max_count = min(int(recv_max_count * 1.3), 256)
+
+            except zmq.ZMQError:
+                # once the queue starts to drain, lower the per-pass receive budget
+                recv_max_count = 64
+
+            try:
+                await asyncio.wait_for(self.schedule_event.wait(), timeout=0.02)
+            except asyncio.TimeoutError:
+                pass
+
+            if self.schedule_event.is_set():
+                self.generate_new_batch()
+                self.schedule_event.clear()
+
+        return
 
     def clean_up(self):
         return
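Two things change here: the router now drains its PULL socket non-blockingly (the socket comes from a plain `zmq.Context` after the `__init__` hunk above) with an adaptive per-pass budget, and the schedule pass piggybacks on the same loop via `schedule_event`. A standalone sketch of the adaptive drain pattern (hypothetical inproc endpoint and integer payloads):

```python
import zmq

ctx = zmq.Context(2)
pull = ctx.socket(zmq.PULL)
pull.bind("inproc://demo")  # the router binds a PULL socket over TCP instead
push = ctx.socket(zmq.PUSH)
push.connect("inproc://demo")

for i in range(100):
    push.send_pyobj(i)

recv_max_count = 64
received = []
for _ in range(3):  # a few passes of the outer loop
    try:
        # pull at most recv_max_count messages so one pass cannot stall the loop
        for _ in range(recv_max_count):
            received.append(pull.recv_pyobj(zmq.NOBLOCK))
        # queue still busy: raise the per-pass budget, capped at 256
        recv_max_count = min(int(recv_max_count * 1.3), 256)
    except zmq.ZMQError:
        # queue drained: reset the budget
        recv_max_count = 64

print(len(received), recv_max_count)  # 100 messages received, budget back to 64
```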
@@ -459,6 +426,13 @@ def start_router_process(args, router_port, detokenization_port, metric_port, pipe_writer):
     graceful_registry(inspect.currentframe().f_code.co_name)
     start_parent_check_thread()
 
+    def handle_exception(loop, context):
+        logger.exception(f"Router Caught exception: {str(context)}")
+
+    loop = asyncio.new_event_loop()
+    loop.set_exception_handler(handle_exception)
+    asyncio.set_event_loop(loop)
+
     try:
         router = RouterManager(
             args,
@@ -467,7 +441,7 @@ def start_router_process(args, router_port, detokenization_port, metric_port, pipe_writer):
             metric_port=metric_port,
         )
 
-        asyncio.run(router.wait_to_model_ready())
+        loop.run_until_complete(router.wait_to_model_ready())
     except:
         import traceback
         import sys
@@ -480,13 +454,6 @@ def start_router_process(args, router_port, detokenization_port, metric_port, pipe_writer):
         raise
 
     pipe_writer.send("init ok")
-
-    def handle_exception(loop, context):
-        logger.exception(f"Router Caught exception: {str(context)}")
-
-    loop = asyncio.new_event_loop()
-    loop.set_exception_handler(handle_exception)
-    asyncio.set_event_loop(loop)
     loop.create_task(router.loop_for_fwd())
     loop.run_until_complete(router.loop_for_netio_req())
     return
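Moving the loop construction ahead of `wait_to_model_ready` puts initialization and the long-running `loop_for_fwd` / `loop_for_netio_req` tasks on one event loop with one exception handler, where previously init ran under `asyncio.run` on a throwaway loop. A minimal sketch of the pattern, assuming stand-in coroutines:

```python
import asyncio
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("router-demo")

def handle_exception(loop, context):
    # invoked for errors no await site catches (failed callbacks, tasks
    # whose exception is never retrieved, ...)
    logger.error(f"Caught exception: {context['message']}")

async def init():  # stands in for router.wait_to_model_ready()
    await asyncio.sleep(0)

def broken_callback():  # an error that only the handler will see
    raise RuntimeError("boom")

async def main_loop():  # stands in for router.loop_for_netio_req()
    await asyncio.sleep(0.05)

loop = asyncio.new_event_loop()
loop.set_exception_handler(handle_exception)
asyncio.set_event_loop(loop)

loop.run_until_complete(init())       # same loop for init...
loop.call_soon(broken_callback)       # error is logged, loop keeps running
loop.run_until_complete(main_loop())  # ...and for the long-running work
```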
