Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions lightllm/server/req_id_generator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import os
import psutil
import sys
import time
import requests
import numpy as np
Expand Down Expand Up @@ -27,6 +30,44 @@ def __init__(self):
self.current_id.arr[0] = 0
self.current_id.arr[1] = 0
self.lock = AtomicShmLock(f"{get_unique_server_name()}_req_id_gen_lock")
self._wait_all_workers_ready()
logger.info("ReqIDGenerator init finished")

def _wait_all_workers_ready(self):
from lightllm.utils.envs_utils import get_unique_server_name
from lightllm.server.core.objs.shm_array import ShmArray

_sync_shm = ShmArray(
f"{get_unique_server_name()}_httpworker_start_sync", (self.args.httpserver_workers,), dtype=np.int64
)
_sync_shm.create_shm()
# 等待所有 httpserver 的 worker 启动完成,防止重新初始化对应的请求id 对应的shm
try_count = 0
while len(_find_sibling_processes()) + 1 != self.args.httpserver_workers:
time.sleep(0.1)
try_count += 1
if try_count > 120:
logger.error("wait all httpserver workers start failed")
sys.exit(-1)
else:
continue

cur_p_id = os.getpid()
pids = _find_sibling_processes()
pids.append(cur_p_id)
assert len(pids) == self.args.httpserver_workers
pids = sorted(pids)
index = pids.index(cur_p_id)
_sync_shm.arr[index] = cur_p_id
try_count = 0
while not all(a == b for a, b in zip(pids, _sync_shm.arr)):
time.sleep(0.1)
try_count += 1
if try_count > 120:
logger.error("wait all httpserver workers start failed 1")
sys.exit(-1)
else:
continue

def _check_and_set_new_id_range(self):
need_update_range = self.current_id.arr[0] + MAX_BEST_OF >= self.current_id.arr[1]
Expand Down Expand Up @@ -66,3 +107,30 @@ def generate_id(self):

def convert_sub_id_to_group_id(sub_req_id):
return (sub_req_id // MAX_BEST_OF) * MAX_BEST_OF


def _find_sibling_processes():
# 获取当前进程的 PID
current_pid = os.getpid()

# 获取当前进程的信息
current_process = psutil.Process(current_pid)

# 获取当前进程的父进程
parent_process = current_process.parent()

if parent_process is None:
logger.error("Current process has no parent.")
return []

# 查找兄弟进程
sibling_processes = []
for proc in psutil.process_iter(["pid", "name"]):
try:
# 检查是否是兄弟进程(同一父进程且不是当前进程)
if proc.pid != current_pid and proc.ppid() == parent_process.pid:
sibling_processes.append(proc)
except (psutil.NoSuchProcess, psutil.AccessDenied):
continue

return [proc.pid for proc in sibling_processes]
8 changes: 2 additions & 6 deletions lightllm/utils/health_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,12 @@
from lightllm.server.httpserver.manager import HttpServerManager
from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt
from fastapi import Request
from lightllm.server.req_id_generator import ReqIDGenerator
from lightllm.utils.log_utils import init_logger
from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args

logger = init_logger(__name__)


_g_health_req_id_gen = ReqIDGenerator()


@dataclass
class HealthObj:
_is_health: bool = False
Expand Down Expand Up @@ -81,9 +77,9 @@ async def health_check(args, httpserver_manager: HttpServerManager, request: Req
if get_env_start_args().run_mode == "pd_master":
# Since the id assigned by pd master needs to be passed to prefill and decode nodes for inference,
# a normal request id is required instead of a negative id.
sampling_params.group_request_id = _g_health_req_id_gen.generate_id()
sampling_params.group_request_id = httpserver_manager.id_gen.generate_id()
else:
sampling_params.group_request_id = -_g_health_req_id_gen.generate_id() # health monitor 的 id 是负的
sampling_params.group_request_id = -httpserver_manager.id_gen.generate_id() # health monitor 的 id 是负的
multimodal_params_dict = request_dict.get("multimodal_params", {})
multimodal_params = MultimodalParams(**multimodal_params_dict)
results_generator = httpserver_manager.generate(
Expand Down