diff --git a/lightllm/server/req_id_generator.py b/lightllm/server/req_id_generator.py
index 445b6213e..9bf9040c3 100644
--- a/lightllm/server/req_id_generator.py
+++ b/lightllm/server/req_id_generator.py
@@ -1,3 +1,6 @@
+import os
+import psutil
+import sys
 import time
 import requests
 import numpy as np
@@ -27,6 +30,48 @@ def __init__(self):
         self.current_id.arr[0] = 0
         self.current_id.arr[1] = 0
         self.lock = AtomicShmLock(f"{get_unique_server_name()}_req_id_gen_lock")
+        self._wait_all_workers_ready()
+        logger.info("ReqIDGenerator init finished")
+
+    def _wait_all_workers_ready(self):
+        """Block until every httpserver worker has started and registered its pid.
+
+        Every worker writes its pid into the slot of a shared array given by the
+        pid's rank among all worker pids, then waits until each slot holds the
+        expected pid. The process exits if the workers do not show up in time.
+        """
+        from lightllm.utils.envs_utils import get_unique_server_name
+        from lightllm.server.core.objs.shm_array import ShmArray
+
+        _sync_shm = ShmArray(
+            f"{get_unique_server_name()}_httpworker_start_sync", (self.args.httpserver_workers,), dtype=np.int64
+        )
+        _sync_shm.create_shm()
+        # Wait until all httpserver workers have started, to avoid re-initializing
+        # the shm that backs the request ids.
+        cur_p_id = os.getpid()
+        try_count = 0
+        while True:
+            pids = _find_sibling_processes()
+            if len(pids) + 1 == self.args.httpserver_workers:
+                break
+            time.sleep(0.1)
+            try_count += 1
+            if try_count > 120:
+                logger.error("wait all httpserver workers start failed")
+                sys.exit(-1)
+
+        # Reuse the snapshot that passed the check above: a second query could
+        # return a different process set and break the slot assignment.
+        pids = sorted(pids + [cur_p_id])
+        _sync_shm.arr[pids.index(cur_p_id)] = cur_p_id
+        try_count = 0
+        while not all(a == b for a, b in zip(pids, _sync_shm.arr)):
+            time.sleep(0.1)
+            try_count += 1
+            if try_count > 120:
+                logger.error("wait all httpserver workers start failed 1")
+                sys.exit(-1)
 
     def _check_and_set_new_id_range(self):
         need_update_range = self.current_id.arr[0] + MAX_BEST_OF >= self.current_id.arr[1]
@@ -66,3 +111,25 @@ def generate_id(self):
 
 def convert_sub_id_to_group_id(sub_req_id):
     return (sub_req_id // MAX_BEST_OF) * MAX_BEST_OF
+
+
+def _find_sibling_processes():
+    """Return the pids of the other processes that share this process's parent.
+
+    Returns an empty list when the current process has no parent or the parent
+    can no longer be inspected.
+    """
+    current_pid = os.getpid()
+    parent_process = psutil.Process(current_pid).parent()
+    if parent_process is None:
+        logger.error("Current process has no parent.")
+        return []
+
+    try:
+        # children() yields the parent's direct children, so the whole process
+        # table does not have to be walked and filtered by ppid here.
+        children = parent_process.children()
+    except (psutil.NoSuchProcess, psutil.AccessDenied):
+        return []
+
+    return [proc.pid for proc in children if proc.pid != current_pid]
diff --git a/lightllm/utils/health_check.py b/lightllm/utils/health_check.py
index ee0778b65..f6c52bdb3 100644
--- a/lightllm/utils/health_check.py
+++ b/lightllm/utils/health_check.py
@@ -8,16 +8,12 @@
 from lightllm.server.httpserver.manager import HttpServerManager
 from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt
 from fastapi import Request
-from lightllm.server.req_id_generator import ReqIDGenerator
 from lightllm.utils.log_utils import init_logger
 from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args
 
 logger = init_logger(__name__)
 
 
-_g_health_req_id_gen = ReqIDGenerator()
-
-
 @dataclass
 class HealthObj:
     _is_health: bool = False
@@ -81,9 +77,9 @@ async def health_check(args, httpserver_manager: HttpServerManager, request: Req
         if get_env_start_args().run_mode == "pd_master":
             # Since the id assigned by pd master needs to be passed to prefill and decode nodes for inference,
             # a normal request id is required instead of a negative id.
-            sampling_params.group_request_id = _g_health_req_id_gen.generate_id()
+            sampling_params.group_request_id = httpserver_manager.id_gen.generate_id()
         else:
-            sampling_params.group_request_id = -_g_health_req_id_gen.generate_id()  # health monitor 的 id 是负的
+            sampling_params.group_request_id = -httpserver_manager.id_gen.generate_id()  # health ids are negative
         multimodal_params_dict = request_dict.get("multimodal_params", {})
         multimodal_params = MultimodalParams(**multimodal_params_dict)
         results_generator = httpserver_manager.generate(