Skip to content

Commit 6a81baa

Browse files
hiworldwzjwangzaijun
andauthored
fix health req id gen when httpserver worker num > 1 (#1137)
Co-authored-by: wangzaijun <wangzaijun@sensetime.com>
1 parent 080ae7e commit 6a81baa

File tree

2 files changed

+70
-6
lines changed

2 files changed

+70
-6
lines changed

lightllm/server/req_id_generator.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import os
2+
import psutil
3+
import sys
14
import time
25
import requests
36
import numpy as np
@@ -27,6 +30,44 @@ def __init__(self):
2730
self.current_id.arr[0] = 0
2831
self.current_id.arr[1] = 0
2932
self.lock = AtomicShmLock(f"{get_unique_server_name()}_req_id_gen_lock")
33+
self._wait_all_workers_ready()
34+
logger.info("ReqIDGenerator init finished")
35+
36+
def _wait_all_workers_ready(self):
37+
from lightllm.utils.envs_utils import get_unique_server_name
38+
from lightllm.server.core.objs.shm_array import ShmArray
39+
40+
_sync_shm = ShmArray(
41+
f"{get_unique_server_name()}_httpworker_start_sync", (self.args.httpserver_workers,), dtype=np.int64
42+
)
43+
_sync_shm.create_shm()
44+
# 等待所有 httpserver 的 worker 启动完成,防止重新初始化对应的请求id 对应的shm
45+
try_count = 0
46+
while len(_find_sibling_processes()) + 1 != self.args.httpserver_workers:
47+
time.sleep(0.1)
48+
try_count += 1
49+
if try_count > 120:
50+
logger.error("wait all httpserver workers start failed")
51+
sys.exit(-1)
52+
else:
53+
continue
54+
55+
cur_p_id = os.getpid()
56+
pids = _find_sibling_processes()
57+
pids.append(cur_p_id)
58+
assert len(pids) == self.args.httpserver_workers
59+
pids = sorted(pids)
60+
index = pids.index(cur_p_id)
61+
_sync_shm.arr[index] = cur_p_id
62+
try_count = 0
63+
while not all(a == b for a, b in zip(pids, _sync_shm.arr)):
64+
time.sleep(0.1)
65+
try_count += 1
66+
if try_count > 120:
67+
logger.error("wait all httpserver workers start failed 1")
68+
sys.exit(-1)
69+
else:
70+
continue
3071

3172
def _check_and_set_new_id_range(self):
3273
need_update_range = self.current_id.arr[0] + MAX_BEST_OF >= self.current_id.arr[1]
@@ -66,3 +107,30 @@ def generate_id(self):
66107

67108
def convert_sub_id_to_group_id(sub_req_id):
68109
return (sub_req_id // MAX_BEST_OF) * MAX_BEST_OF
110+
111+
112+
def _find_sibling_processes():
113+
# 获取当前进程的 PID
114+
current_pid = os.getpid()
115+
116+
# 获取当前进程的信息
117+
current_process = psutil.Process(current_pid)
118+
119+
# 获取当前进程的父进程
120+
parent_process = current_process.parent()
121+
122+
if parent_process is None:
123+
logger.error("Current process has no parent.")
124+
return []
125+
126+
# 查找兄弟进程
127+
sibling_processes = []
128+
for proc in psutil.process_iter(["pid", "name"]):
129+
try:
130+
# 检查是否是兄弟进程(同一父进程且不是当前进程)
131+
if proc.pid != current_pid and proc.ppid() == parent_process.pid:
132+
sibling_processes.append(proc)
133+
except (psutil.NoSuchProcess, psutil.AccessDenied):
134+
continue
135+
136+
return [proc.pid for proc in sibling_processes]

lightllm/utils/health_check.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,12 @@
88
from lightllm.server.httpserver.manager import HttpServerManager
99
from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt
1010
from fastapi import Request
11-
from lightllm.server.req_id_generator import ReqIDGenerator
1211
from lightllm.utils.log_utils import init_logger
1312
from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args
1413

1514
logger = init_logger(__name__)
1615

1716

18-
_g_health_req_id_gen = ReqIDGenerator()
19-
20-
2117
@dataclass
2218
class HealthObj:
2319
_is_health: bool = False
@@ -81,9 +77,9 @@ async def health_check(args, httpserver_manager: HttpServerManager, request: Req
8177
if get_env_start_args().run_mode == "pd_master":
8278
# Since the id assigned by pd master needs to be passed to prefill and decode nodes for inference,
8379
# a normal request id is required instead of a negative id.
84-
sampling_params.group_request_id = _g_health_req_id_gen.generate_id()
80+
sampling_params.group_request_id = httpserver_manager.id_gen.generate_id()
8581
else:
86-
sampling_params.group_request_id = -_g_health_req_id_gen.generate_id() # health monitor 的 id 是负的
82+
sampling_params.group_request_id = -httpserver_manager.id_gen.generate_id() # health monitor 的 id 是负的
8783
multimodal_params_dict = request_dict.get("multimodal_params", {})
8884
multimodal_params = MultimodalParams(**multimodal_params_dict)
8985
results_generator = httpserver_manager.generate(

0 commit comments

Comments
 (0)