@@ -4,6 +4,6 @@
 import time
 import threading
 import torch.distributed as dist
-from typing import List, Tuple, Callable, Optional
+from typing import Dict, List, Literal, Tuple, Callable, Optional
 from transformers.configuration_utils import PretrainedConfig
 from lightllm.utils.infer_utils import set_random_seed
 from lightllm.utils.log_utils import init_logger
@@ -39,6 +39,7 @@
 from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token
 from lightllm.server.pd_io_struct import NIXLChunckedTransTaskRet
 from .multi_level_kv_cache import MultiLevelKvCacheModule
+from lightllm.utils.profiler import ProcessProfiler, ProfilerCmd
 
 
 class ModeBackend:
@@ -218,11 +219,19 @@ def init_model(self, kvargs):
         if self.args.mtp_mode:
             self.init_mtp_draft_model(kvargs)
 
+        self.profiler: Optional[ProcessProfiler] = None
+        if self.args.enable_profiling:
+            self.profiler = ProcessProfiler(
+                mode=self.args.enable_profiling,
+                name=f"lightllm-model_backend-node{self.node_rank}_dev{get_current_device_id()}",
+            )
+        self.profiling_active = False
+
         # Launch two infer_loop threads that both run inference: when two batches'
         # inference can be folded (overlapped), CPU overhead drops and GPU utilization improves substantially.
-        self.infer_loop_thread = threading.Thread(target=self.infer_loop, daemon=True)
+        self.infer_loop_thread = threading.Thread(target=self.infer_loop, daemon=True, name="loop0")
         self.infer_loop_thread.start()
-        self.infer_loop_thread1 = threading.Thread(target=self.infer_loop, daemon=True)
+        self.infer_loop_thread1 = threading.Thread(target=self.infer_loop, daemon=True, name="loop1")
         self.infer_loop_thread1.start()
         return
 
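Naming the loop threads matters mainly for observability: sampling and tracing tools report stacks per thread, and `loop0`/`loop1` are easier to tell apart than the default `Thread-1`/`Thread-2`. A minimal standalone illustration (the stub below is not lightllm code; the motivation is inferred from the profiler support added in this PR):

```python
# Illustration only: named threads are easy to identify in profiler output.
import threading


def infer_loop_stub():
    # A profiler would attribute samples from this frame to the thread's name.
    print(f"running in thread {threading.current_thread().name!r}")


threads = [
    threading.Thread(target=infer_loop_stub, daemon=True, name=n)
    for n in ("loop0", "loop1")
]
for t in threads:
    t.start()
for t in threads:
    t.join()
```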
@@ -308,6 +317,14 @@ def _try_read_new_reqs(self):
             self._try_read_new_reqs_multinode_tp()
         else:
             self._try_read_new_reqs_normal()
+
+        # polled on each loop thread: apply the requested profiling state
+        if self.profiler is not None:
+            if self.profiler.is_active != self.profiling_active:
+                if self.profiling_active:
+                    self.profiler.start()
+                else:
+                    self.profiler.stop()
         return
 
     def _try_read_new_reqs_normal(self):
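The block added to `_try_read_new_reqs` is a small state reconciliation that every inference-loop iteration runs: `self.profiling_active` holds the requested state (flipped by the `ProfilerCmd` handling further down), `self.profiler.is_active` holds the actual state, and each loop thread converges the latter toward the former. A standalone sketch of the same pattern, with `DummyProfiler` standing in for lightllm's `ProcessProfiler` (whose internals are not shown in this diff):

```python
# Sketch of desired-state vs. actual-state reconciliation, as used in
# _try_read_new_reqs. DummyProfiler is a stand-in, not lightllm's class.
class DummyProfiler:
    def __init__(self) -> None:
        self.is_active = False

    def start(self) -> None:
        self.is_active = True
        print("profiler started")

    def stop(self) -> None:
        self.is_active = False
        print("profiler stopped")


profiler = DummyProfiler()
profiling_active = False  # desired state, set by command handling


def reconcile() -> None:
    # Called once per loop iteration; a no-op whenever the states agree.
    if profiler.is_active != profiling_active:
        if profiling_active:
            profiler.start()
        else:
            profiler.stop()


reconcile()             # no-op: both inactive
profiling_active = True
reconcile()             # mismatch observed: starts the profiler
reconcile()             # no-op again
```

Because both loop threads poll the same profiler, whichever thread observes the mismatch first performs the transition, and the check becomes a no-op for the other.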
@@ -373,6 +390,11 @@ def _read_reqs_buffer_and_init_reqs(self):
                 if obj.req_id in g_infer_context.requests_mapping:
                     req: InferReq = g_infer_context.requests_mapping[obj.req_id]
                     req.infer_aborted = True
+            elif isinstance(obj, ProfilerCmd):
+                if obj.cmd == "start":
+                    self.profiling_active = True
+                elif obj.cmd == "stop":
+                    self.profiling_active = False
             else:
                 assert False, f"error type {type(obj)}"
         if init_reqs: