Commit cd3b691 (1 parent: d6d59ec)

feat: Implementing Past Future Scheduler

File tree: 7 files changed, +114 −13 lines

lightllm/server/api_cli.py

Lines changed: 6 additions & 0 deletions

@@ -228,6 +228,12 @@ def make_argument_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--router_max_new_token_len", type=int, default=1024, help="the request max new token len for router"
     )
+    parser.add_argument(
+        "--past_future_scheduler",
+        action="store_true",
+        help="""use past_future_scheduler for adaptive request new token len prediction,
+        override --router_token_ratio and --router_max_new_token_len (still used during warmup)""",
+    )
 
     parser.add_argument(
         "--router_max_wait_tokens",

lightllm/server/core/objs/req.py

Lines changed: 2 additions & 2 deletions

@@ -311,7 +311,7 @@ def print_time_log(self, log_info: str):
 class ChunkedPrefillReq(Req):
     _pack_ = 4
 
-    def get_tuple_tokens(self, is_busy, router_max_new_token_len):
+    def get_tuple_tokens(self, is_busy, router_max_new_token_len, has_out_len_factor=1.1):
         args = get_env_start_args()
         # During chunked prefill inference there are several modes of delayed-step control, used to
         # guarantee better inter-packet pacing or to improve prefill efficiency under dp mode; but when estimating token memory
@@ -328,7 +328,7 @@ def get_tuple_tokens(self, is_busy, router_max_new_token_len):
             cur_max_new_token_len = self.sample_params.max_new_tokens
         else:
             cur_max_new_token_len = min(
-                self.sample_params.max_new_tokens, max(int(1.1 * has_out_len), router_max_new_token_len)
+                self.sample_params.max_new_tokens, max(int(has_out_len_factor * has_out_len), router_max_new_token_len)
             )
 
         a_len = max(self.input_len + has_out_len + 1, self.shm_cur_kv_len + 1)
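To see what the new has_out_len_factor parameter changes, a quick numeric check with illustrative values (not taken from the repo):

    # With the default factor, a request keeps 10% headroom over what it has
    # already emitted, floored by the router's static estimate.
    max_new_tokens, has_out_len, router_max = 2048, 1200, 1024
    cur = min(max_new_tokens, max(int(1.1 * has_out_len), router_max))  # -> 1320

    # PastFutureQueue (added below) calls get_tuple_tokens with
    # has_out_len_factor=1.0 and a sampled length in place of router_max,
    # so its prediction is used without the 10% inflation.
    sampled_len = 1500
    cur = min(max_new_tokens, max(int(1.0 * has_out_len), sampled_len))  # -> 1500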

lightllm/server/httpserver/manager.py

Lines changed: 3 additions & 3 deletions

@@ -676,9 +676,9 @@ async def recycle_resource_loop(self):
                 continue
 
             logger.info(
-                f"left req id {req_status.group_req_objs.group_req_id}"
-                f"can release {req_status.group_req_objs.shm_req_objs[0].can_released_mark} "
-                f"refcount {req_status.group_req_objs.shm_req_objs[0].ref_count}"
+                f"left req id: {req_status.group_req_objs.group_req_id}, "
+                f"can release: {req_status.group_req_objs.shm_req_objs[0].can_released_mark}, "
+                f"refcount: {req_status.group_req_objs.shm_req_objs[0].ref_count}"
             )
         return
lightllm/server/router/manager.py

Lines changed: 3 additions & 0 deletions

@@ -27,6 +27,7 @@
 from lightllm.server.core.objs.shm_objs_io_buffer import ShmObjsIOBuffer
 from lightllm.utils.log_utils import init_logger, log_time_ready
 from lightllm.server.router.token_load import TokenLoad
+from lightllm.server.router.req_queue.chunked_prefill.impl_past_future import PastFutureQueue
 from lightllm.server.metrics.manager import MetricClient
 from lightllm.common.basemodel.infer_lock import g_router_lock
 from lightllm.common.mem_manager import ReadOnlyStaticsMemoryManager

@@ -353,6 +354,8 @@ def _add_new_batch_to_running_batch(self, new_batch: Batch):
 
     def _filter_reqs_from_running_batch(self):
         if self.running_batch is not None:
+            if isinstance(self.req_queue, PastFutureQueue):
+                self.req_queue.record_finished_len_from_batch(self.running_batch)
             self.running_batch.filter_out_finished_req(self.shm_req_manager)
             if self.running_batch.is_clear():
                 self.running_batch = None
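The hook is placed before filter_out_finished_req so that PastFutureQueue can still read each finished request's shm_cur_output_len into its history window before the request objects are dropped from the running batch (see record_finished_len_from_batch in the new file below).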

lightllm/server/router/req_queue/__init__.py

Lines changed: 14 additions & 6 deletions

@@ -2,31 +2,39 @@
 from .chunked_prefill.impl import ChunkedPrefillQueue
 from .chunked_prefill.beam_impl import ChunkedBeamContinuesBatchQueue
 from .chunked_prefill.impl_for_nixl_pd import NIXLPDQueue
+from .chunked_prefill.impl_past_future import PastFutureQueue
 from .dp_base_queue import DpQueue
 
 
 def _get_req_queue_class(args, router, dp_size_in_node: int):
+    if args.past_future_scheduler:
+        if args.diverse_mode:
+            raise ValueError("Diverse mode is not supported with past future scheduler yet")
+        chunked_prefill_queue_impl = PastFutureQueue
+    else:
+        chunked_prefill_queue_impl = ChunkedPrefillQueue
+
     if args.diverse_mode:
         return ChunkedBeamContinuesBatchQueue
     if args.token_healing_mode:
-        return ChunkedPrefillQueue
+        return chunked_prefill_queue_impl
     if args.output_constraint_mode != "none":
-        return ChunkedPrefillQueue
+        return chunked_prefill_queue_impl
     if args.first_token_constraint_mode:
-        return ChunkedPrefillQueue
+        return chunked_prefill_queue_impl
     if args.run_mode in ["decode"]:
         return QueueForPDDecode
     if args.run_mode in ["prefill"]:
-        return ChunkedPrefillQueue
+        return chunked_prefill_queue_impl
     if args.run_mode in ["nixl_prefill", "nixl_decode"]:
         return NIXLPDQueue
 
     if args.disable_chunked_prefill:
         # Although this also uses the chunked prefill queue, since args.chunked_prefill_size = args.max_req_total_len,
         # the actual scheduling behavior resembles the old continuous-batching scheduler, so the two scheduling
         # implementations are unified into one to reduce code duplication.
-        return ChunkedPrefillQueue
+        return chunked_prefill_queue_impl
     else:
-        return ChunkedPrefillQueue
+        return chunked_prefill_queue_impl
 
 
 def build_req_queue(args, router, dp_size_in_node: int):
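A minimal dispatch check under the new flag (hypothetical; argparse.Namespace stands in for the real start args, with only the fields the function reads populated, and "normal" assumed as the default run_mode):

    from argparse import Namespace

    args = Namespace(past_future_scheduler=True, diverse_mode=False,
                     token_healing_mode=False, output_constraint_mode="none",
                     first_token_constraint_mode=False, run_mode="normal",
                     disable_chunked_prefill=False)
    # Every branch that previously returned ChunkedPrefillQueue now resolves
    # to PastFutureQueue; combining the flag with diverse_mode raises early.
    assert _get_req_queue_class(args, router=None, dp_size_in_node=1) is PastFutureQueue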

lightllm/server/router/req_queue/chunked_prefill/impl.py

Lines changed: 6 additions & 2 deletions

@@ -21,8 +21,7 @@ def _init_cache_list(self, current_batch: Batch, is_busy):
             self.cache_len_list = []
             return
 
-    # @calculate_time(show=True, min_cost_ms=0.1)
-    def _can_add_new_req(self, req: Req, is_busy, new_batch_first_router_need_tokens):
+    def _update_cache_len_list(self, req: Req, is_busy):
         self.cache_len_list.append(req.get_tuple_tokens(is_busy, self.router_max_new_token_len))  # hard to analysis
         self.cache_len_list.sort(key=lambda x: -x[1])
 
@@ -32,6 +31,11 @@ def _can_add_new_req(self, req: Req, is_busy, new_batch_first_router_need_tokens
         size_array = np.arange(1, len(self.cache_len_list) + 1, 1)
 
         need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max()
+        return need_max_token_num
+
+    # @calculate_time(show=True, min_cost_ms=0.1)
+    def _can_add_new_req(self, req: Req, is_busy, new_batch_first_router_need_tokens):
+        need_max_token_num = self._update_cache_len_list(req, is_busy)
         with g_router_lock.obj:
             ok_token_num = (
                 need_max_token_num + self.router.shared_token_load.get_frozened_token_count(self.dp_index)
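The factored-out peak estimate is the piece PastFutureQueue reuses, so it is worth a concrete walk-through. The same computation as a standalone sketch, with made-up (has_run_len, left_out_len) tuples:

    import numpy as np

    # Tuples are (has_run_len, left_out_len), as returned by get_tuple_tokens;
    # the values below are illustrative only.
    cache_len_list = [(100, 50), (20, 400), (300, 10)]
    cache_len_list.sort(key=lambda x: -x[1])  # longest remaining output first

    left_out_len_array = np.array([e[1] for e in cache_len_list])  # [400, 50, 10]
    cum_run_len_array = np.cumsum([e[0] for e in cache_len_list])  # [20, 120, 420]
    size_array = np.arange(1, len(cache_len_list) + 1)             # [1, 2, 3]

    # While the i-th request (in this order) finishes its remaining tokens,
    # i+1 requests are still alive, on top of everything already generated;
    # the scheduler budgets for the worst such moment.
    need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max()
    print(need_max_token_num)  # 450 = max(400*1 + 20, 50*2 + 120, 10*3 + 420)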
lightllm/server/router/req_queue/chunked_prefill/impl_past_future.py

Lines changed: 80 additions & 0 deletions

@@ -0,0 +1,80 @@
+import bisect
+from collections import deque
+import random
+from typing import List, Tuple
+import numpy as np
+from ...batch import Batch, Req
+from .impl import ChunkedPrefillQueue
+
+
+class PastFutureQueue(ChunkedPrefillQueue):
+    WINDOW_SIZE = 200
+    MINIMUM_SAMPLES = 200
+    MAXIMUM_LISTS = 5
+    REVERSED = 0.05
+    COMPLIANCE_IS_BUSY_FLAG = False
+
+    def __init__(self, args, router, dp_index, dp_size_in_node) -> None:
+        super().__init__(args, router, dp_index, dp_size_in_node)
+        initial_len = args.router_max_new_token_len
+        self.history_output_len = deque([initial_len] * (self.WINDOW_SIZE // 2), maxlen=self.WINDOW_SIZE)
+
+    def _sample_cache_list(self, reqs: List[Req], is_busy, samples=1) -> List[List[Tuple[int, int]]]:
+        cache_len_lists = [[] for _ in range(samples)]
+        his_Lo = sorted(self.history_output_len)
+        for req in reqs:
+            dl = req.shm_cur_output_len
+            pos = bisect.bisect(his_Lo, dl)
+
+            sample_range = [dl] + his_Lo[pos:] + [req.sample_params.max_new_tokens]  # at least 2 values
+
+            for i in range(samples):
+                random_p = np.random.random() * (len(sample_range) - 1)
+                l_pos = int(random_p)
+                l_val, r_val = sample_range[l_pos : l_pos + 2]
+
+                # Linear interpolation
+                sampled = round(l_val + (r_val - l_val) * (random_p - l_pos))
+                cache_len_lists[i].append(
+                    req.get_tuple_tokens(is_busy and self.COMPLIANCE_IS_BUSY_FLAG, sampled, has_out_len_factor=1.0)
+                )
+
+        return cache_len_lists
+
+    def _calc_max_token_num_needed(self, cache_len_list: List[Tuple[int, int]]) -> int:
+        cache_len_list.sort(key=lambda x: -x[1])
+
+        left_out_len_array = np.array([e[1] for e in cache_len_list])
+        has_run_len_array = np.array([e[0] for e in cache_len_list])
+        cum_run_len_array = np.cumsum(has_run_len_array)
+        size_array = np.arange(1, len(cache_len_list) + 1, 1)
+
+        need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max()
+        return need_max_token_num
+
+    def _init_cache_list(self, current_batch: Batch, is_busy):
+        if current_batch is not None:
+            n_lists = min(self.MAXIMUM_LISTS, int(self.MINIMUM_SAMPLES / len(current_batch.reqs)) + 1)
+            local_reqs = [req for req in current_batch.reqs if req.sample_params.suggested_dp_index == self.dp_index]
+            self._cache_len_lists = self._sample_cache_list(local_reqs, is_busy, samples=n_lists)
+        else:
+            self._cache_len_lists = [[]]
+        self.cache_len_list = self._cache_len_lists[0]  # keep compatibility
+
+    def _update_cache_len_list(self, req: Req, is_busy):
+        need_max_token_nums = []
+        for li in self._cache_len_lists:
+            newreq_output_len_sample = random.choice(self.history_output_len)
+            li.append(
+                req.get_tuple_tokens(
+                    is_busy and self.COMPLIANCE_IS_BUSY_FLAG, newreq_output_len_sample, has_out_len_factor=1.0
+                )
+            )
+            need_max_token_nums.append(self._calc_max_token_num_needed(li))
+        need_max_token_num = np.max(need_max_token_nums)
+        return need_max_token_num
+
+    def record_finished_len_from_batch(self, batch: Batch):
+        for req in batch.reqs:
+            if req.shm_infer_released:
+                self.history_output_len.append(req.shm_cur_output_len)
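The core sampling idea is easiest to see in isolation. A self-contained sketch (simplified names, no lightllm types; history plays the role of history_output_len): given a request that has already produced dl tokens, a final length is drawn from the historical output lengths still reachable (those greater than dl), anchored below by dl and above by max_new_tokens, with linear interpolation between adjacent order statistics.

    import bisect
    import numpy as np

    def sample_final_len(dl, history, max_new_tokens):
        # Candidate final lengths: the tokens already produced (dl), every
        # historical output length still reachable (> dl), and the hard cap.
        his = sorted(history)
        pos = bisect.bisect(his, dl)
        sample_range = [dl] + his[pos:] + [max_new_tokens]

        # Uniform random position along the sorted candidates, with linear
        # interpolation between the two neighbouring values.
        random_p = np.random.random() * (len(sample_range) - 1)
        l_pos = int(random_p)
        l_val, r_val = sample_range[l_pos : l_pos + 2]
        return round(l_val + (r_val - l_val) * (random_p - l_pos))

    # A request that has emitted 120 tokens, capped at 512 new tokens, drawn
    # against an illustrative history window:
    history = [64, 100, 150, 200, 300]
    print(sample_final_len(120, history, 512))  # some value in [120, 512]

Each of the up-to-MAXIMUM_LISTS sampled batch snapshots gets an independent draw per request, and _update_cache_len_list takes the max of the per-snapshot peak estimates, so admission is gated by the most pessimistic of the sampled futures.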
