|
| 1 | +import os |
| 2 | +import psutil |
| 3 | +import sys |
1 | 4 | import time |
2 | 5 | import requests |
3 | 6 | import numpy as np |
@@ -27,6 +30,44 @@ def __init__(self): |
27 | 30 | self.current_id.arr[0] = 0 |
28 | 31 | self.current_id.arr[1] = 0 |
29 | 32 | self.lock = AtomicShmLock(f"{get_unique_server_name()}_req_id_gen_lock") |
| 33 | + self._wait_all_workers_ready() |
| 34 | + logger.info("ReqIDGenerator init finished") |
| 35 | + |
| 36 | + def _wait_all_workers_ready(self): |
| 37 | + from lightllm.utils.envs_utils import get_unique_server_name |
| 38 | + from lightllm.server.core.objs.shm_array import ShmArray |
| 39 | + |
| 40 | + _sync_shm = ShmArray( |
| 41 | + f"{get_unique_server_name()}_httpworker_start_sync", (self.args.httpserver_workers,), dtype=np.int64 |
| 42 | + ) |
| 43 | + _sync_shm.create_shm() |
| 44 | + # 等待所有 httpserver 的 worker 启动完成,防止重新初始化对应的请求id 对应的shm |
| 45 | + try_count = 0 |
| 46 | + while len(_find_sibling_processes()) + 1 != self.args.httpserver_workers: |
| 47 | + time.sleep(0.1) |
| 48 | + try_count += 1 |
| 49 | + if try_count > 120: |
| 50 | + logger.error("wait all httpserver workers start failed") |
| 51 | + sys.exit(-1) |
| 52 | + else: |
| 53 | + continue |
| 54 | + |
| 55 | + cur_p_id = os.getpid() |
| 56 | + pids = _find_sibling_processes() |
| 57 | + pids.append(cur_p_id) |
| 58 | + assert len(pids) == self.args.httpserver_workers |
| 59 | + pids = sorted(pids) |
| 60 | + index = pids.index(cur_p_id) |
| 61 | + _sync_shm.arr[index] = cur_p_id |
| 62 | + try_count = 0 |
| 63 | + while not all(a == b for a, b in zip(pids, _sync_shm.arr)): |
| 64 | + time.sleep(0.1) |
| 65 | + try_count += 1 |
| 66 | + if try_count > 120: |
| 67 | + logger.error("wait all httpserver workers start failed 1") |
| 68 | + sys.exit(-1) |
| 69 | + else: |
| 70 | + continue |
30 | 71 |
|
31 | 72 | def _check_and_set_new_id_range(self): |
32 | 73 | need_update_range = self.current_id.arr[0] + MAX_BEST_OF >= self.current_id.arr[1] |
@@ -66,3 +107,30 @@ def generate_id(self): |
66 | 107 |
|
67 | 108 | def convert_sub_id_to_group_id(sub_req_id): |
68 | 109 | return (sub_req_id // MAX_BEST_OF) * MAX_BEST_OF |
| 110 | + |
| 111 | + |
| 112 | +def _find_sibling_processes(): |
| 113 | + # 获取当前进程的 PID |
| 114 | + current_pid = os.getpid() |
| 115 | + |
| 116 | + # 获取当前进程的信息 |
| 117 | + current_process = psutil.Process(current_pid) |
| 118 | + |
| 119 | + # 获取当前进程的父进程 |
| 120 | + parent_process = current_process.parent() |
| 121 | + |
| 122 | + if parent_process is None: |
| 123 | + logger.error("Current process has no parent.") |
| 124 | + return [] |
| 125 | + |
| 126 | + # 查找兄弟进程 |
| 127 | + sibling_processes = [] |
| 128 | + for proc in psutil.process_iter(["pid", "name"]): |
| 129 | + try: |
| 130 | + # 检查是否是兄弟进程(同一父进程且不是当前进程) |
| 131 | + if proc.pid != current_pid and proc.ppid() == parent_process.pid: |
| 132 | + sibling_processes.append(proc) |
| 133 | + except (psutil.NoSuchProcess, psutil.AccessDenied): |
| 134 | + continue |
| 135 | + |
| 136 | + return [proc.pid for proc in sibling_processes] |
0 commit comments