Skip to content

Commit c6d977e

Browse files
SiYu Wu (WuSiYu)
authored and committed
feat(misc): Profiler support
use --enable_profiling=MODE to enable; currently supports torch_profiler and nvtx (use with NVIDIA Nsight Systems) modes
1 parent d6d59ec commit c6d977e

File tree

8 files changed

+262
-7
lines changed

8 files changed

+262
-7
lines changed

lightllm/server/api_cli.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -537,4 +537,19 @@ def make_argument_parser() -> argparse.ArgumentParser:
537537
parser.add_argument(
538538
"--disk_cache_storage_size", type=float, default=10, help="""The capacity of disk cache. GB used."""
539539
)
540+
parser.add_argument(
541+
"--enable_profiling",
542+
type=str,
543+
choices=["torch_profiler", "nvtx"],
544+
default=None,
545+
help="""Enable profiler support.
546+
This will expose '/profiler_start' and '/profiler_stop' API,
547+
below profiling features will only been enabled in this range.
548+
Options:
549+
'torch_profiler': will setup torch.profiler.profile(), traces file will been saved to './trace',
550+
or set by 'LIGHTLLM_TRACE_DIR' env;
551+
'nvtx': will add NVTX marks for external profiler like NVIDIA Nsight System
552+
(you should setup it by youself).
553+
A NVTX named 'LIGHTLLM_PROFILE' will been added within the profiling range.""",
554+
)
540555
return parser

lightllm/server/api_http.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,24 @@ async def kv_move_status(websocket: WebSocket):
335335
return
336336

337337

338+
@app.get("/profiler_start")
339+
async def profiler_start() -> Response:
340+
if g_objs.args.enable_profiling:
341+
await g_objs.httpserver_manager.profiler_cmd("start")
342+
return {"status": "ok"}
343+
else:
344+
return JSONResponse({"message": "Profiling support not enabled"}, status_code=500)
345+
346+
347+
@app.get("/profiler_stop")
348+
async def profiler_stop() -> Response:
349+
if g_objs.args.enable_profiling:
350+
await g_objs.httpserver_manager.profiler_cmd("stop")
351+
return {"status": "ok"}
352+
else:
353+
return JSONResponse({"message": "Profiling support not enabled"}, status_code=500)
354+
355+
338356
@app.on_event("shutdown")
339357
async def shutdown():
340358
logger.info("Received signal to shutdown. Performing graceful shutdown...")

lightllm/server/httpserver/manager.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from frozendict import frozendict
1414

1515
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
16-
from typing import Union, List, Tuple, Dict, Optional
16+
from typing import Literal, Union, List, Tuple, Dict, Optional
1717
from websockets import ClientConnection
1818
from fastapi import Request
1919
from ..tokenizer import get_tokenizer
@@ -35,6 +35,7 @@
3535
from lightllm.utils.config_utils import get_vocab_size
3636
from lightllm.utils.envs_utils import get_unique_server_name
3737
from lightllm.utils.error_utils import NixlPrefillNodeStopGenToken
38+
from lightllm.utils.profiler import ProfilerCmd
3839
from rpyc.utils.classic import obtain
3940

4041
logger = init_logger(__name__)
@@ -642,6 +643,16 @@ async def abort(self, group_req_id: int) -> bool:
642643
logger.warning(f"aborted group_request_id {group_req_objs.group_req_id}")
643644
return True
644645

646+
async def profiler_cmd(self, cmd: Literal["start", "stop"]):
647+
receivers = [self.send_to_router]
648+
if self.pd_mode.is_P_or_NORMAL() and self.enable_multimodal:
649+
receivers.append(self.send_to_visual)
650+
for receiver in receivers:
651+
receiver.send_pyobj(
652+
ProfilerCmd(cmd),
653+
protocol=pickle.HIGHEST_PROTOCOL,
654+
)
655+
645656
async def recycle_resource_loop(self):
646657
pre_time_mark = time.time()
647658

lightllm/server/router/manager.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from lightllm.server.multi_level_kv_cache.cpu_cache_client import CpuKvCacheClient
2727
from lightllm.server.core.objs.shm_objs_io_buffer import ShmObjsIOBuffer
2828
from lightllm.utils.log_utils import init_logger, log_time_ready
29+
from lightllm.utils.profiler import ProcessProfiler, ProfilerCmd
2930
from lightllm.server.router.token_load import TokenLoad
3031
from lightllm.server.metrics.manager import MetricClient
3132
from lightllm.common.basemodel.infer_lock import g_router_lock
@@ -106,6 +107,10 @@ def __init__(self, args: StartArgs):
106107
if not self.args.enable_cpu_cache
107108
else CpuKvCacheClient(only_create_meta_data=True, init_shm_data=False)
108109
)
110+
111+
self.profiler = (
112+
ProcessProfiler(mode=args.enable_profiling, name="lightllm-router") if args.enable_profiling else None
113+
)
109114
return
110115

111116
async def wait_to_model_ready(self):
@@ -508,16 +513,28 @@ def _multinode_tp_generate_new_batch(self):
508513
raise e
509514
return
510515

516+
async def _profiler_cmd(self, cmd_obj: ProfilerCmd):
517+
self.profiler.cmd(cmd_obj)
518+
519+
cmd = ProfilerCmd(cmd=cmd_obj.cmd)
520+
while not self.shm_reqs_io_buffer.is_empty():
521+
await asyncio.sleep(0.02)
522+
523+
self.shm_reqs_io_buffer.write_obj([cmd])
524+
self.shm_reqs_io_buffer.set_ready()
525+
511526
async def _recv_new_reqs_and_schedule(self):
512527
if not hasattr(self, "recv_max_count"):
513528
self.recv_max_count = 64
514529

515530
try:
516531
# 一次最多从 zmq 中取 recv_max_count 个请求,防止 zmq 队列中请求数量过多导致阻塞了主循环。
517532
for _ in range(self.recv_max_count):
518-
recv_req: GroupReqIndexes = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK)
533+
recv_req: Union[GroupReqIndexes, ProfilerCmd] = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK)
519534
if isinstance(recv_req, GroupReqIndexes):
520535
self._add_req(recv_req)
536+
elif isinstance(recv_req, ProfilerCmd):
537+
await self._profiler_cmd(recv_req)
521538
else:
522539
assert False, f"Error Req Inf {recv_req}"
523540

lightllm/server/router/model_infer/mode_backend/base_backend.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import time
55
import threading
66
import torch.distributed as dist
7-
from typing import List, Tuple, Callable, Optional
7+
from typing import Dict, List, Literal, Tuple, Callable, Optional
88
from transformers.configuration_utils import PretrainedConfig
99
from lightllm.utils.infer_utils import set_random_seed
1010
from lightllm.utils.log_utils import init_logger
@@ -39,6 +39,7 @@
3939
from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token
4040
from lightllm.server.pd_io_struct import NIXLChunckedTransTaskRet
4141
from .multi_level_kv_cache import MultiLevelKvCacheModule
42+
from lightllm.utils.profiler import ProcessProfiler, ProfilerCmd
4243

4344

4445
class ModeBackend:
@@ -218,11 +219,19 @@ def init_model(self, kvargs):
218219
if self.args.mtp_mode:
219220
self.init_mtp_draft_model(kvargs)
220221

222+
self.profiler: Optional[ProcessProfiler] = None
223+
if self.args.enable_profiling:
224+
self.profiler = ProcessProfiler(
225+
mode=self.args.enable_profiling,
226+
name=f"lightllm-model_backend-node{self.node_rank}_dev{get_current_device_id()}",
227+
)
228+
self.profiling_active = False
229+
221230
# 启动infer_loop_thread, 启动两个线程进行推理,对于具备双batch推理折叠得场景
222231
# 可以降低 cpu overhead,大幅提升gpu得使用率。
223-
self.infer_loop_thread = threading.Thread(target=self.infer_loop, daemon=True)
232+
self.infer_loop_thread = threading.Thread(target=self.infer_loop, daemon=True, name="loop0")
224233
self.infer_loop_thread.start()
225-
self.infer_loop_thread1 = threading.Thread(target=self.infer_loop, daemon=True)
234+
self.infer_loop_thread1 = threading.Thread(target=self.infer_loop, daemon=True, name="loop1")
226235
self.infer_loop_thread1.start()
227236
return
228237

@@ -308,6 +317,14 @@ def _try_read_new_reqs(self):
308317
self._try_read_new_reqs_multinode_tp()
309318
else:
310319
self._try_read_new_reqs_normal()
320+
321+
# on each loop thread
322+
if self.profiler is not None:
323+
if self.profiler.is_active != self.profiling_active:
324+
if self.profiling_active:
325+
self.profiler.start()
326+
else:
327+
self.profiler.stop()
311328
return
312329

313330
def _try_read_new_reqs_normal(self):
@@ -373,6 +390,11 @@ def _read_reqs_buffer_and_init_reqs(self):
373390
if obj.req_id in g_infer_context.requests_mapping:
374391
req: InferReq = g_infer_context.requests_mapping[obj.req_id]
375392
req.infer_aborted = True
393+
elif isinstance(obj, ProfilerCmd):
394+
if obj.cmd == "start":
395+
self.profiling_active = True
396+
elif obj.cmd == "stop":
397+
self.profiling_active = False
376398
else:
377399
assert False, f"error type {type(obj)}"
378400
if init_reqs:

lightllm/server/visualserver/manager.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import pickle
88
import inspect
99
import setproctitle
10-
from typing import List
10+
from typing import List, Union
1111
from lightllm.server.core.objs.io_objs.group_req import GroupReqIndexes
1212
from lightllm.server.core.objs import ShmReqManager, StartArgs
1313

@@ -18,6 +18,7 @@
1818
from lightllm.utils.graceful_utils import graceful_registry
1919
from lightllm.utils.process_check import start_parent_check_thread
2020
from lightllm.utils.envs_utils import get_unique_server_name
21+
from lightllm.utils.profiler import ProcessProfiler, ProfilerCmd
2122
from rpyc.utils.classic import obtain
2223

2324

@@ -58,6 +59,9 @@ def __init__(
5859
self.args = args
5960
self.visual_model_rpc_ports = visual_model_rpc_ports
6061
self.shm_req_manager = ShmReqManager()
62+
self.profiler: "ProcessProfiler|None" = (
63+
ProcessProfiler(args.enable_profiling, name="lightllm-visual_server") if args.enable_profiling else None
64+
)
6165

6266
async def wait_to_model_ready(self):
6367

@@ -90,6 +94,7 @@ async def wait_to_model_ready(self):
9094
"quant_type": self.args.vit_quant_type,
9195
"quant_cfg": self.args.vit_quant_cfg,
9296
"max_batch_size": min(self.infer_batch_size // self.vit_dp, 1),
97+
"profiler": self.args.enable_profiling,
9398
}
9499
init_model_ret.append(self.model_rpcs[dp_rank_id][tp_rank_id].init_model(kvargs))
95100
await asyncio.gather(*init_model_ret)
@@ -171,9 +176,19 @@ async def loop_for_netio_req(self):
171176
while True:
172177
try:
173178
for _ in range(self.visual_recv_max_count):
174-
recv_req: GroupReqIndexes = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK)
179+
recv_req: Union[GroupReqIndexes, ProfilerCmd] = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK)
175180
if isinstance(recv_req, GroupReqIndexes):
176181
self.waiting_reqs.append(recv_req)
182+
elif isinstance(recv_req, ProfilerCmd):
183+
self.profiler.cmd(recv_req)
184+
tasks = []
185+
for vit_dp_rank in range(self.vit_dp):
186+
for vit_tp_rank in range(self.vit_tp):
187+
task = asyncio.create_task(
188+
self.model_rpcs[vit_dp_rank][vit_tp_rank].profiler_cmd(recv_req)
189+
)
190+
tasks.append(task)
191+
await asyncio.gather(*tasks)
177192
else:
178193
assert False, f"Error Req Inf {recv_req}"
179194
self.visual_recv_max_count = min(self.visual_recv_max_count * 1.3, 256)

lightllm/server/visualserver/model_infer/model_rpc.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from lightllm.utils.dist_utils import init_vision_distributed_env
2525
from lightllm.utils.graceful_utils import graceful_registry
2626
from lightllm.utils.envs_utils import get_env_start_args
27+
from lightllm.utils.profiler import ProcessProfiler
2728

2829

2930
class VisualModelRpcServer(rpyc.Service):
@@ -42,6 +43,13 @@ def exposed_init_model(self, kvargs):
4243
self.cache_client = rpyc.connect("localhost", self.cache_port, config={"allow_pickle": True})
4344
self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
4445
self.data_type = kvargs["data_type"]
46+
self.profiler = (
47+
ProcessProfiler(
48+
mode=kvargs["profiler"], name=f"lightllm-visual-vit_dp{self.dp_rank_id}_tp{self.tp_rank_id}"
49+
)
50+
if kvargs["profiler"]
51+
else None
52+
)
4553

4654
init_vision_distributed_env(kvargs)
4755
model_cfg, _ = PretrainedConfig.get_config_dict(weight_dir)
@@ -116,6 +124,10 @@ def exposed_encode(self, images: List[ImageItem]):
116124
self.cache_client.root.set_items_embed(ids_to_set)
117125
return
118126

127+
def exposed_profiler_cmd(self, cmd_obj):
128+
cmd_obj = obtain(cmd_obj)
129+
self.profiler.cmd(cmd_obj)
130+
119131

120132
class VisualModelRpcClient:
121133
def __init__(self, model_rpc, vit_tp, rpc_server_process=None):
@@ -138,9 +150,11 @@ async def _func(*args, **kwargs):
138150

139151
self._init_model = async_wrap(self.model.init_model)
140152
self._encode = async_wrap(self.model.encode)
153+
self._profiler_cmd = async_wrap(self.model.profiler_cmd)
141154
else:
142155
self._init_model = self.model.exposed_init_model
143156
self._encode = self.model.exposed_encode
157+
self._profiler_cmd = self.model.exposed_profiler_cmd
144158
return
145159

146160
async def init_model(self, kvargs):
@@ -158,6 +172,14 @@ async def encode(self, images: List[ImageItem]):
158172
else:
159173
return ans
160174

175+
async def profiler_cmd(self, cmd_obj):
176+
ans: rpyc.AsyncResult = self._profiler_cmd(cmd_obj)
177+
if self.use_rpc:
178+
await ans
179+
return
180+
else:
181+
return
182+
161183

162184
def _init_env(port, device_id):
163185
# 注册graceful 退出的处理

0 commit comments

Comments
 (0)