diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py
index bf0e89887..07f70fb66 100644
--- a/lightllm/server/api_cli.py
+++ b/lightllm/server/api_cli.py
@@ -572,4 +572,19 @@ def make_argument_parser() -> argparse.ArgumentParser:
         default=False,
         help="""Enable prefix prompt cache fetch for data parallel inference, disabled by default.""",
     )
+    parser.add_argument(
+        "--enable_profiling",
+        type=str,
+        choices=["torch_profiler", "nvtx"],
+        default=None,
+        help="""Enable profiler support.
+        This exposes the '/profiler_start' and '/profiler_stop' APIs; the profiling features below
+        are only active between those two calls.
+        Options:
+        'torch_profiler': sets up torch.profiler.profile(); trace files are saved to './trace',
+        or to the directory given by the 'LIGHTLLM_TRACE_DIR' env var;
+        'nvtx': adds NVTX marks for an external profiler such as NVIDIA Nsight Systems
+        (which you must set up yourself).
+        An NVTX range named 'LIGHTLLM_PROFILE' wraps the profiling range.""",
+    )
     return parser
diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py
index 8bda50fb7..eddce01a9 100755
--- a/lightllm/server/api_http.py
+++ b/lightllm/server/api_http.py
@@ -335,6 +335,24 @@ async def kv_move_status(websocket: WebSocket):
     return
 
 
+@app.get("/profiler_start")
+async def profiler_start() -> Response:
+    if g_objs.args.enable_profiling:
+        await g_objs.httpserver_manager.profiler_cmd("start")
+        return JSONResponse({"status": "ok"})
+    else:
+        return JSONResponse({"message": "Profiling support not enabled"}, status_code=400)
+
+
+@app.get("/profiler_stop")
+async def profiler_stop() -> Response:
+    if g_objs.args.enable_profiling:
+        await g_objs.httpserver_manager.profiler_cmd("stop")
+        return JSONResponse({"status": "ok"})
+    else:
+        return JSONResponse({"message": "Profiling support not enabled"}, status_code=400)
+
+
 @app.on_event("shutdown")
 async def shutdown():
     logger.info("Received signal to shutdown. Performing graceful shutdown...")
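# Usage sketch (illustration only): driving the new endpoints from a client, assuming the
# server was started with --enable_profiling=... and is reachable at http://localhost:8000
# (hypothetical host/port); uses the third-party `requests` package.
import requests

base = "http://localhost:8000"
requests.get(f"{base}/profiler_start").raise_for_status()  # open the profiled range
# ... send inference requests here; only this window is captured ...
requests.get(f"{base}/profiler_stop").raise_for_status()   # close it; torch_profiler mode writes traces to LIGHTLLM_TRACE_DIR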
Performing graceful shutdown...") diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index e1cb32b88..3fac5bc0a 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -13,7 +13,7 @@ from frozendict import frozendict asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) -from typing import Union, List, Tuple, Dict, Optional, AsyncGenerator +from typing import Literal, Union, List, Tuple, Dict, Optional, AsyncGenerator from websockets import ClientConnection from fastapi import Request from ..tokenizer import get_tokenizer @@ -35,6 +35,7 @@ from lightllm.utils.config_utils import get_vocab_size from lightllm.utils.envs_utils import get_unique_server_name from lightllm.utils.error_utils import NixlPrefillNodeStopGenToken +from lightllm.utils.profiler import ProfilerCmd from rpyc.utils.classic import obtain logger = init_logger(__name__) @@ -650,6 +651,16 @@ async def abort(self, group_req_id: int) -> bool: logger.warning(f"aborted group_request_id {group_req_objs.group_req_id}") return True + async def profiler_cmd(self, cmd: Literal["start", "stop"]): + receivers = [self.send_to_router] + if self.pd_mode.is_P_or_NORMAL() and self.enable_multimodal: + receivers.append(self.send_to_visual) + for receiver in receivers: + receiver.send_pyobj( + ProfilerCmd(cmd), + protocol=pickle.HIGHEST_PROTOCOL, + ) + async def recycle_resource_loop(self): pre_time_mark = time.time() diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index 89c46d9ed..c12865f4f 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -26,6 +26,7 @@ from lightllm.server.multi_level_kv_cache.cpu_cache_client import CpuKvCacheClient from lightllm.server.core.objs.shm_objs_io_buffer import ShmObjsIOBuffer from lightllm.utils.log_utils import init_logger, log_time_ready +from lightllm.utils.profiler import ProcessProfiler, ProfilerCmd from lightllm.server.router.token_load import TokenLoad from lightllm.server.metrics.manager import MetricClient from lightllm.common.basemodel.infer_lock import g_router_lock @@ -106,6 +107,9 @@ def __init__(self, args: StartArgs): if not self.args.enable_cpu_cache else CpuKvCacheClient(only_create_meta_data=True, init_shm_data=False) ) + + profiler_mode = args.enable_profiling + self.profiler = ProcessProfiler(mode=profiler_mode, name="lightllm-router") if profiler_mode else None return async def wait_to_model_ready(self): @@ -504,6 +508,16 @@ def _multinode_tp_generate_new_batch(self): raise e return + async def _profiler_cmd(self, cmd_obj: ProfilerCmd): + self.profiler.cmd(cmd_obj) + + cmd = ProfilerCmd(cmd=cmd_obj.cmd) + while not self.shm_reqs_io_buffer.is_empty(): + await asyncio.sleep(0.02) + + self.shm_reqs_io_buffer.write_obj([cmd]) + self.shm_reqs_io_buffer.set_ready() + async def _recv_new_reqs_and_schedule(self): if not hasattr(self, "recv_max_count"): self.recv_max_count = 64 @@ -511,9 +525,11 @@ async def _recv_new_reqs_and_schedule(self): try: # 一次最多从 zmq 中取 recv_max_count 个请求,防止 zmq 队列中请求数量过多导致阻塞了主循环。 for _ in range(self.recv_max_count): - recv_req: GroupReqIndexes = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) + recv_req: Union[GroupReqIndexes, ProfilerCmd] = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) if isinstance(recv_req, GroupReqIndexes): self._add_req(recv_req) + elif isinstance(recv_req, ProfilerCmd): + await self._profiler_cmd(recv_req) else: assert False, f"Error Req Inf {recv_req}" diff --git 
diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py
index a780c4da0..81fd1d283 100644
--- a/lightllm/server/router/model_infer/mode_backend/base_backend.py
+++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py
@@ -39,6 +39,7 @@ from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token
 from lightllm.server.pd_io_struct import NIXLChunckedTransTaskRet
 from .multi_level_kv_cache import MultiLevelKvCacheModule
+from lightllm.utils.profiler import ProcessProfiler, ProfilerCmd
 
 
 class ModeBackend:
@@ -231,6 +232,10 @@ def init_model(self, kvargs):
         if self.args.mtp_mode:
             self.init_mtp_draft_model(kvargs)
 
+        prof_name = f"lightllm-model_backend-node{self.node_rank}_dev{get_current_device_id()}"
+        prof_mode = self.args.enable_profiling
+        self.profiler = ProcessProfiler(mode=prof_mode, name=prof_name, use_multi_thread=True) if prof_mode else None
+
         # Start infer_loop_thread: inference runs in two threads; in scenarios with dual-batch folding
         # this reduces CPU overhead and greatly improves GPU utilization.
         self.infer_loop_thread = threading.Thread(target=self.infer_loop, daemon=True)
@@ -343,6 +348,10 @@ def _try_read_new_reqs(self):
             self._try_read_new_reqs_multinode_tp()
         else:
             self._try_read_new_reqs_normal()
+
+        # runs on each inference loop thread
+        if self.profiler is not None:
+            self.profiler.multi_thread_helper()
         return
 
     def _try_read_new_reqs_normal(self):
@@ -408,6 +417,8 @@ def _read_reqs_buffer_and_init_reqs(self):
                 if obj.req_id in g_infer_context.requests_mapping:
                     req: InferReq = g_infer_context.requests_mapping[obj.req_id]
                     req.infer_aborted = True
+            elif isinstance(obj, ProfilerCmd):
+                self.profiler.cmd(obj)
             else:
                 assert False, f"error type {type(obj)}"
         if init_reqs:
diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py
index a389272e5..3e02b4e1b 100644
--- a/lightllm/server/visualserver/manager.py
+++ b/lightllm/server/visualserver/manager.py
@@ -7,7 +7,7 @@ import pickle
 import inspect
 import setproctitle
 
-from typing import List
+from typing import List, Union
 from lightllm.server.core.objs.io_objs.group_req import GroupReqIndexes
 from lightllm.server.core.objs import ShmReqManager, StartArgs
@@ -18,6 +18,7 @@ from lightllm.utils.graceful_utils import graceful_registry
 from lightllm.utils.process_check import start_parent_check_thread
 from lightllm.utils.envs_utils import get_unique_server_name
+from lightllm.utils.profiler import ProcessProfiler, ProfilerCmd
 from rpyc.utils.classic import obtain
 
 
@@ -59,6 +60,8 @@ def __init__(
         self.visual_model_rpc_ports = visual_model_rpc_ports
         self.send_batch_size = args.visual_send_batch_size
         self.shm_req_manager = ShmReqManager()
+        prof_mode = args.enable_profiling
+        self.profiler = ProcessProfiler(prof_mode, name="lightllm-visual_server") if prof_mode else None
 
     async def wait_to_model_ready(self):
 
@@ -185,9 +188,17 @@ async def loop_for_netio_req(self):
         while True:
             try:
                 for _ in range(self.visual_recv_max_count):
-                    recv_req: GroupReqIndexes = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK)
+                    recv_req: Union[GroupReqIndexes, ProfilerCmd] = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK)
                     if isinstance(recv_req, GroupReqIndexes):
                         self.waiting_reqs.append(recv_req)
+                    elif isinstance(recv_req, ProfilerCmd):
+                        self.profiler.cmd(recv_req)
+                        tasks = []
+                        for dp in range(self.vit_dp):
+                            for tp in range(self.vit_tp):
+                                task = asyncio.create_task(self.model_rpcs[dp][tp].profiler_cmd(recv_req))
+                                tasks.append(task)
+                        await asyncio.gather(*tasks)
                     else:
                         assert False, f"Error Req Inf {recv_req}"
                 self.visual_recv_max_count = min(self.visual_recv_max_count * 1.3, 256)
diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py
index a9409ceb9..3582fb392 100644
--- a/lightllm/server/visualserver/model_infer/model_rpc.py
+++ b/lightllm/server/visualserver/model_infer/model_rpc.py
@@ -24,6 +24,7 @@ from lightllm.utils.dist_utils import init_vision_distributed_env
 from lightllm.utils.graceful_utils import graceful_registry
 from lightllm.utils.envs_utils import get_env_start_args
+from lightllm.utils.profiler import ProcessProfiler
 
 
 class VisualModelRpcServer(rpyc.Service):
@@ -43,6 +44,9 @@ def exposed_init_model(self, kvargs):
         self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
         self.data_type = kvargs["data_type"]
 
+        prof_mode = get_env_start_args().enable_profiling
+        prof_name = f"lightllm-visual-vit_dp{self.dp_rank_id}_tp{self.tp_rank_id}"
+        self.profiler = ProcessProfiler(mode=prof_mode, name=prof_name) if prof_mode else None
         init_vision_distributed_env(kvargs)
         model_cfg, _ = PretrainedConfig.get_config_dict(weight_dir)
@@ -116,6 +120,10 @@ def exposed_encode(self, images: List[ImageItem]):
         self.cache_client.root.set_items_embed(ids_to_set)
         return
 
+    def exposed_profiler_cmd(self, cmd_obj):
+        cmd_obj = obtain(cmd_obj)
+        self.profiler.cmd(cmd_obj)
+
 
 class VisualModelRpcClient:
     def __init__(self, model_rpc, vit_tp, rpc_server_process=None):
@@ -138,9 +146,11 @@ async def _func(*args, **kwargs):
 
             self._init_model = async_wrap(self.model.init_model)
             self._encode = async_wrap(self.model.encode)
+            self._profiler_cmd = async_wrap(self.model.profiler_cmd)
         else:
             self._init_model = self.model.exposed_init_model
             self._encode = self.model.exposed_encode
+            self._profiler_cmd = self.model.exposed_profiler_cmd
         return
 
     async def init_model(self, kvargs):
@@ -158,6 +168,14 @@ async def encode(self, images: List[ImageItem]):
         else:
             return ans
 
+    async def profiler_cmd(self, cmd_obj):
+        ans: rpyc.AsyncResult = self._profiler_cmd(cmd_obj)
+        if self.use_rpc:
+            await ans
+            return
+        else:
+            return
+
 
 def _init_env(port, device_id):
     # register the graceful-exit handler
+ """ + self.mode = mode + self.name = name or "unnamed" + self.use_multi_thread = use_multi_thread + self.torch_profiler_with_stack = torch_profiler_with_stack + + self.is_active: bool = False # Process-level logical state + self._threadlocal = threading.local() + + # make sure only one active torch.profiler per process + self._lock = threading.Lock() + self._process_torch_profiler_active_tid: int | None = None + + if self.mode == "torch_profiler": + self._trace_dir = os.getenv("LIGHTLLM_TRACE_DIR", "./trace") + os.makedirs(self._trace_dir, exist_ok=True) + elif self.mode == "nvtx": + self._nvtx_toplevel_mark = "LIGHTLLM_PROFILE" + else: + raise ValueError("invalid profiler mode") + + self._log_init_info() + + @property + def _local(self): + """Lazy initialization of thread-local storage.""" + if not hasattr(self._threadlocal, "initialized"): + self._threadlocal.initialized = True + self._threadlocal.is_active = False + self._threadlocal.profiler_obj = None + self._threadlocal.nvtx_range_id = None + return self._threadlocal + + def _log_init_info(self): + logger.warning("-" * 50) + logger.warning( + f"[pid={os.getpid()} tid={_get_thread_id()}] Profiler <{self.name}> initialized with mode: {self.mode}" + ) + if self.mode == "torch_profiler": + logger.warning( + "Profiler support for torch.profiler enabled (--enable_profiling=torch_profiler), " + "trace files will be saved to %s (change it with LIGHTLLM_TRACE_DIR env var)", + self._trace_dir, + ) + elif self.mode == "nvtx": + logger.warning( + "Profiler support for NVTX enabled (--enable_profiling=nvtx), toplevel NVTX mark is '%s'\n" + "you can use it with external profiling tools like NVIDIA Nsight Systems.", + self._nvtx_toplevel_mark, + ) + logger.warning( + "e.g. nsys profile --capture-range=nvtx --nvtx-capture=%s --trace=cuda,nvtx " + "-e NSYS_NVTX_PROFILER_REGISTER_ONLY=0 [other nsys options] " + "python -m lightllm.server.api_server --enable_profiling=nvtx [other lightllm options]", + self._nvtx_toplevel_mark, + ) + logger.warning("Use /profiler_start and /profiler_stop HTTP GET APIs to start/stop profiling") + logger.warning("DO NOT enable this feature in production environment") + logger.warning("-" * 50) + + def _torch_profiler_start(self) -> None: + with self._lock: + if self._process_torch_profiler_active_tid is not None: + return + self._process_torch_profiler_active_tid = _get_thread_id() + + torch.cuda.synchronize() + worker_name = f"{self.name}_tid{_get_thread_id()}" if self.use_multi_thread else self.name + + trace_handler = torch.profiler.tensorboard_trace_handler( + self._trace_dir, + worker_name=worker_name, + use_gzip=True, + ) + + p = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + schedule=None, + with_stack=self.torch_profiler_with_stack, + record_shapes=True, + on_trace_ready=trace_handler, + ) + + self._local.profiler_obj = p + p.start() + torch.cuda.synchronize() + + def _nvtx_start(self) -> None: + torch.cuda.synchronize() + self._local.nvtx_range_id = torch.cuda.nvtx.range_start(self._nvtx_toplevel_mark) + torch.cuda.synchronize() + + def _thread_start(self) -> None: + if self._local.is_active: + return + + try: + logger.info(f"[{self.name} @ tid={_get_thread_id()}] Start Profiler.") + if self.mode == "torch_profiler": + self._torch_profiler_start() + elif self.mode == "nvtx": + self._nvtx_start() + + self._local.is_active = True + except Exception as e: + logger.error( + f"[{self.name} @ tid={_get_thread_id()}] Failed to start profiler 
+
+    def _torch_profiler_stop(self) -> None:
+        if self._process_torch_profiler_active_tid != _get_thread_id():
+            return
+
+        torch.cuda.synchronize()
+        logger.info(f"[{self.name} @ tid={_get_thread_id()}] Saving trace (blocking)...")
+        try:
+            if self._local.profiler_obj:
+                self._local.profiler_obj.stop()
+        except Exception as e:
+            logger.error(f"[{self.name} @ tid={_get_thread_id()}] Error stopping torch profiler: {e}")
+        finally:
+            self._local.profiler_obj = None  # Explicitly release reference to allow GC
+            self._process_torch_profiler_active_tid = None
+
+        torch.cuda.synchronize()
+
+    def _nvtx_stop(self) -> None:
+        torch.cuda.synchronize()
+        if self._local.nvtx_range_id is not None:
+            torch.cuda.nvtx.range_end(self._local.nvtx_range_id)
+            self._local.nvtx_range_id = None
+        torch.cuda.synchronize()
+
+    def _thread_stop(self) -> None:
+        if not self._local.is_active:
+            return
+
+        try:
+            if self.mode == "torch_profiler":
+                self._torch_profiler_stop()
+            elif self.mode == "nvtx":
+                self._nvtx_stop()
+            logger.info(f"[{self.name} @ tid={_get_thread_id()}] Profiler stopped.")
+        except Exception as e:
+            logger.error(f"[{self.name} @ tid={_get_thread_id()}] Failed to stop profiler: {e}")
+        finally:
+            # Mark inactive regardless of success to avoid repeated errors
+            self._local.is_active = False
+
+    def start(self) -> None:
+        self.is_active = True
+        if not self.use_multi_thread:
+            self._thread_start()
+
+    def stop(self) -> None:
+        self.is_active = False
+        if not self.use_multi_thread:
+            self._thread_stop()
+
+    def multi_thread_helper(self) -> None:
+        """
+        **only for multi-threading use cases**
+        Worker polling method. Must be called within the inference loop.
+        """
+        if not self.use_multi_thread:
+            return
+
+        # Catch-all to prevent profiler errors from crashing inference logic
+        try:
+            local_active = self._local.is_active
+
+            if self.is_active and not local_active:
+                self._thread_start()
+            elif not self.is_active and local_active:
+                self._thread_stop()
+        except Exception:
+            pass
+
+    def cmd(self, cmd_obj: ProfilerCmd) -> None:
+        if cmd_obj.cmd == "start":
+            self.start()
+        elif cmd_obj.cmd == "stop":
+            self.stop()
+        else:
+            raise ValueError(f"Invalid profiler cmd: {cmd_obj.cmd}")
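# Standalone usage sketch for ProcessProfiler (illustration only), assuming a CUDA-capable
# environment since both modes call torch.cuda.synchronize(). Single-threaded use toggles
# the profiler directly through cmd()/start()/stop().
import torch
from lightllm.utils.profiler import ProcessProfiler, ProfilerCmd

prof = ProcessProfiler(mode="torch_profiler", name="demo")
prof.cmd(ProfilerCmd("start"))  # same entry point the server processes use
torch.matmul(torch.randn(512, 512, device="cuda"), torch.randn(512, 512, device="cuda"))
prof.cmd(ProfilerCmd("stop"))   # writes the trace into LIGHTLLM_TRACE_DIR (default ./trace)

# Multi-threaded variant: start()/stop() only flip the process-level flag; each worker
# thread attaches or detaches itself by polling multi_thread_helper() inside its loop
# (shown inline here for brevity).
mt_prof = ProcessProfiler(mode="nvtx", name="demo-mt", use_multi_thread=True)
mt_prof.start()
mt_prof.multi_thread_helper()  # opens the NVTX range on the calling thread
mt_prof.stop()
mt_prof.multi_thread_helper()  # closes it again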