 import os
 import torch
 import torch.distributed as dist
+import torch.multiprocessing as mp
 from typing import List, Union
+from lightllm.common.kv_trans_kernel.kv_trans_v2 import kv_trans_for_dp
 from lightllm.server.pd_io_struct import KVMoveTask
 from lightllm.utils.log_utils import init_logger
 from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt
 from lightllm.utils.profile_max_tokens import get_available_gpu_memory, get_total_gpu_memory
 from lightllm.common.kv_trans_kernel.kv_trans import kv_trans
-from lightllm.utils.dist_utils import get_current_rank_in_node
+from lightllm.utils.dist_utils import get_current_rank_in_node, get_node_world_size
 from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args
 from lightllm.distributed.pynccl import PyNcclCommunicator
 from lightllm.utils.dist_utils import get_current_device_id
 from lightllm.utils.config_utils import get_num_key_value_heads
 from lightllm.common.kv_trans_kernel.nixl_kv_trans import page_io
+from lightllm.utils.device_utils import kv_trans_use_p2p
+from lightllm.utils.shm_utils import create_or_link_shm
+from multiprocessing.reduction import ForkingPickler
+from filelock import FileLock

 logger = init_logger(__name__)

@@ -401,6 +407,84 @@ def get_index_kv_buffer(self, index): |
     def load_index_kv_buffer(self, index, load_tensor_dict):
         self.kv_buffer[:, index].copy_(load_tensor_dict["kv_buffer"])

+    def copy_kv_from_other_dp_ranks(
+        self,
+        mem_managers: List["MemoryManager"],
+        move_token_indexes: torch.Tensor,
+        token_dp_indexes: torch.Tensor,
+        mem_indexes: torch.Tensor,
+        dp_size_in_node: int,
+        rank_in_dp: int,
+    ):
+        if not hasattr(self, "mem_ptrs_tensor"):
+            # cache the kv_buffer base pointer of every mem manager (one entry per rank in the node)
+            # as a pinned CPU tensor, so it can be copied to the GPU with a non-blocking transfer
+            mems_ptr_list = []
+            for i in range(0, len(mem_managers)):
+                mems_ptr_list.append(mem_managers[i].kv_buffer.data_ptr())
+            self.mem_ptrs_tensor = torch.tensor(mems_ptr_list, dtype=torch.uint64, device="cpu", pin_memory=True)
+
+        # transfer all layers in a single kernel launch
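+        # kv_trans_for_dp is expected to gather, for each entry of move_token_indexes, the kv row
+        # from the peer mem manager selected by token_dp_indexes and write it into this rank's
+        # kv_buffer at the matching mem_indexes position; dp_size_in_node and rank_in_dp describe
+        # how dp ranks map onto mem_managers (semantics inferred from the argument names, not
+        # from the kernel source)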
+        kv_trans_for_dp(
+            input_mems=self.mem_ptrs_tensor.cuda(non_blocking=True),
+            input_idx=move_token_indexes,
+            input_dp_idx=token_dp_indexes,
+            output=self.kv_buffer,
+            output_idx=mem_indexes,
+            dp_size_in_node=dp_size_in_node,
+            rank_in_dp=rank_in_dp,
+        )
+
+    def write_to_shm(self, req_manager):
+        """
+        Write the mem manager into shm so that features such as pd disaggregation can read it
+        directly from there instead of relying on inter-process queues.
+        """
+        if kv_trans_use_p2p():
+            from lightllm.server.router.model_infer.mode_backend.continues_batch.pd_mode.p2p_fix import reduce_tensor
+
+            mp.reductions.reduce_tensor.__code__ = reduce_tensor.__code__
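+            # note: the line above swaps the byte code of torch.multiprocessing's tensor reducer
+            # for the p2p_fix variant, which appears to be needed so the CUDA kv_buffer can be
+            # pickled below and rebuilt in other processes when p2p kv transfer is enabled
+            # (inferred from context; see p2p_fix.reduce_tensor for the exact behavior)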
+
+        from lightllm.common.req_manager import ReqManager
+
+        req_manager: ReqManager = req_manager
+
+        # This is not a very elegant design, but for now it lets the dp shared kv swap module
+        # access req_to_token_indexs in req_manager directly, avoiding a lot of useless data
+        # copying and transfer overhead.
+        self.req_to_token_indexs: torch.Tensor = req_manager.req_to_token_indexs
+
+        lock = FileLock(f"/tmp/{get_unique_server_name()}_mem_manager_lock")
+        with lock:
+            node_world_size = get_node_world_size()
+            shm_name = f"{get_unique_server_name()}_mem_manager_{get_current_rank_in_node()}"
+            obj_bytes_array = [ForkingPickler.dumps(self).tobytes() for _ in range(node_world_size * 2)]
+            obj_size = len(obj_bytes_array[0])
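+            # shm layout: bytes [0:4) hold the number of unread copies, bytes [4:8) the size of
+            # one pickled copy, followed by node_world_size * 2 identical pickled copies of this
+            # MemoryManager, so each reader can pop its own copy in loads_from_shm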
+            shm = create_or_link_shm(
+                name=shm_name, expected_size=obj_size * (node_world_size * 2) + 4 + 4, force_mode="create"
+            )
+            logger.info(f"create shm {shm.name} size {shm.size} for mem manager shared buffer")
+            shm.buf[0:4] = (node_world_size * 2).to_bytes(4, "little")
+            shm.buf[4:8] = obj_size.to_bytes(4, "little")
+            start_index = 8
+            for obj_bytes in obj_bytes_array:
+                shm.buf[start_index : start_index + obj_size] = obj_bytes
+                start_index += obj_size
+
+    @staticmethod
+    def loads_from_shm(rank_in_node: int) -> "MemoryManager":
+        shm_name = f"{get_unique_server_name()}_mem_manager_{rank_in_node}"
+        lock = FileLock(f"/tmp/{get_unique_server_name()}_mem_manager_lock")
+        logger.info(f"get mem manager from shm {shm_name}")
+        with lock:
+            shm = create_or_link_shm(name=shm_name, expected_size=-1, force_mode="link")
+            left_num = int.from_bytes(shm.buf[0:4], "little")
+            obj_size = int.from_bytes(shm.buf[4:8], "little")
+            assert left_num > 0
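+            # each call consumes the last remaining copy and decrements the counter written by
+            # write_to_shm, so a serialized mem manager can be loaded at most node_world_size * 2
+            # times before the assert above fires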
+            end_index = 8 + left_num * obj_size
+            start_index = 8 + (left_num - 1) * obj_size
+            obj_bytes = shm.buf[start_index:end_index].tobytes()
+            shm.buf[0:4] = (left_num - 1).to_bytes(4, byteorder="little")
+            shm.close()
+            return ForkingPickler.loads(obj_bytes)
+

 class ReadOnlyStaticsMemoryManager:
     """
|