|
| 1 | +import torch |
| 2 | +import numpy as np |
| 3 | +from .mem_manager import MemoryManager |
| 4 | +from typing import List, Union |
| 5 | +from lightllm.utils.log_utils import init_logger |
| 6 | +from lightllm.utils.envs_utils import get_page_size |
| 7 | + |
| 8 | + |
| 9 | +def cdiv(a, b): |
| 10 | + return (a + b - 1) // b |
| 11 | + |
| 12 | + |
| 13 | +logger = init_logger(__name__) |
| 14 | + |
| 15 | + |
| 16 | +class PageSizeVariableMemoryManager(MemoryManager): |
| 17 | + def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): |
| 18 | + super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction) |
| 19 | + self.req_to_page_indexs = None |
| 20 | + page_size = get_page_size() |
| 21 | + self.page_idx_pool = torch.arange( |
| 22 | + 0, cdiv(self.size, page_size), dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True |
| 23 | + ) |
| 24 | + self.mark_page_start = 0 |
| 25 | + self.can_use_page_size = cdiv(self.size, page_size) |
| 26 | + |
| 27 | + def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): |
| 28 | + self.kv_buffer = torch.empty( |
| 29 | + (layer_num, cdiv(size, get_page_size()) * get_page_size(), 2 * head_num, head_dim), |
| 30 | + dtype=dtype, |
| 31 | + device="cuda", |
| 32 | + ) |
| 33 | + |
| 34 | + # 要求长度必须是page_size的整数倍,page内token索引必须连续 |
| 35 | + def check_cache_page_valid(self, values: torch.Tensor): |
| 36 | + end = len(values) |
| 37 | + assert end % self.page_size == 0, "Values length must be a multiple of page size" |
| 38 | + total_pages = end // self.page_size |
| 39 | + for page_idx in range(total_pages): |
| 40 | + values_start = page_idx * self.page_size |
| 41 | + values_end = min((page_idx + 1) * self.page_size, end) |
| 42 | + page_token_idxs = values[values_start:values_end] |
| 43 | + if len(page_token_idxs) > 1: |
| 44 | + expected_idxs = torch.arange( |
| 45 | + page_token_idxs[0], |
| 46 | + page_token_idxs[0] + len(page_token_idxs), |
| 47 | + dtype=page_token_idxs.dtype, |
| 48 | + device=page_token_idxs.device, |
| 49 | + ) |
| 50 | + if not torch.equal(page_token_idxs, expected_idxs): |
| 51 | + return False |
| 52 | + return True |
| 53 | + |
| 54 | + def set_prefix_cache_to_req(self, req_idx: int, start: int, end: int, values: torch.Tensor): |
| 55 | + # assert self.check_cache_page_valid(values), "Values must be valid for page size" |
| 56 | + page_size = get_page_size() |
| 57 | + self.req_to_page_indexs[req_idx, start // page_size : end // page_size] = values[::page_size] // page_size |
| 58 | + self.req_to_token_indexs[req_idx, start:end] = values |
| 59 | + |
| 60 | + def expand_by_page_size(self, b_token_len, page_size): |
| 61 | + # 将seq_len按page整数倍展开,例如seq_len = [9,9,9] -> page_len = [4,4,1,4,4,1,4,4,1], page_size = 4 |
| 62 | + b_page_len = cdiv(b_token_len, page_size) |
| 63 | + need_pages_num = b_page_len.sum() |
| 64 | + p_token_len = torch.full((need_pages_num,), page_size, dtype=b_token_len.dtype, device=b_token_len.device) |
| 65 | + cumsum_pages = torch.cumsum(b_page_len, dim=0) |
| 66 | + last_page_positions = cumsum_pages - 1 |
| 67 | + remainders = b_token_len - (b_page_len - 1) * page_size |
| 68 | + p_token_len[last_page_positions] = remainders |
| 69 | + return need_pages_num, b_page_len, p_token_len |
| 70 | + |
| 71 | + def get_paged_token_indexs(self, b_req_idx, page_size, b_seq_len, b_ready_cache_len, is_prefill): |
| 72 | + if is_prefill: |
| 73 | + b_req_idx = b_req_idx.cuda() |
| 74 | + b_seq_len = b_seq_len.cuda() |
| 75 | + b_ready_cache_len = b_ready_cache_len.cuda() |
| 76 | + |
| 77 | + b_token_len = b_seq_len - b_ready_cache_len |
| 78 | + total_pages_needed, b_page_len, p_token_len = self.expand_by_page_size(b_token_len, page_size) |
| 79 | + if self.can_use_page_size < total_pages_needed: |
| 80 | + raise RuntimeError( |
| 81 | + f"No available pages for alloc. remaining: {self.can_use_page_size}, needed: {total_pages_needed}" |
| 82 | + ) |
| 83 | + |
| 84 | + allocated_pages = self.page_idx_pool[ |
| 85 | + self.mark_page_start : self.mark_page_start + total_pages_needed |
| 86 | + ].cuda() |
| 87 | + |
| 88 | + def get_offsets_by_length(b_len, max_len): |
| 89 | + # 例:b_len = [3,4,5] -> [0,1,2,0,1,2,3,0,1,2,3,4] |
| 90 | + offsets = torch.arange(max_len, dtype=b_len.dtype, device=b_len.device) |
| 91 | + offset_mask = offsets.unsqueeze(0) < b_len.unsqueeze(1) |
| 92 | + return torch.masked_select(offsets, offset_mask) |
| 93 | + |
| 94 | + page_offsets = get_offsets_by_length(b_page_len, b_page_len.max()) |
| 95 | + token_offsets = get_offsets_by_length(p_token_len, page_size) |
| 96 | + |
| 97 | + # 更新req_to_page_indexs, b_ready_cache_len必整除page_size |
| 98 | + page_starts = b_ready_cache_len // page_size |
| 99 | + req_id = torch.repeat_interleave( |
| 100 | + torch.arange(len(b_req_idx), dtype=b_token_len.dtype, device=b_token_len.device), b_page_len |
| 101 | + ) |
| 102 | + self.req_to_page_indexs[b_req_idx[req_id], page_starts[req_id] + page_offsets] = allocated_pages |
| 103 | + |
| 104 | + self.mark_page_start += total_pages_needed |
| 105 | + self.can_use_page_size -= total_pages_needed |
| 106 | + page_bases = allocated_pages * page_size |
| 107 | + return torch.repeat_interleave(page_bases, p_token_len) + token_offsets |
| 108 | + else: |
| 109 | + b_seq_len = b_seq_len.cuda() |
| 110 | + b_req_idx = b_req_idx.cuda() |
| 111 | + need_new_page_mask = (b_seq_len - 1) % page_size == 0 |
| 112 | + new_pages_num = need_new_page_mask.sum() |
| 113 | + if self.can_use_page_size < new_pages_num: |
| 114 | + raise RuntimeError( |
| 115 | + f"No available pages for alloc. remaining: {self.can_use_page_size}, needed: {new_pages_num}" |
| 116 | + ) |
| 117 | + |
| 118 | + token_idxs = torch.zeros_like(b_seq_len, device=b_seq_len.device) |
| 119 | + if new_pages_num > 0: |
| 120 | + new_pages = self.page_idx_pool[self.mark_page_start : self.mark_page_start + new_pages_num].cuda() |
| 121 | + self.mark_page_start += new_pages_num |
| 122 | + self.can_use_page_size -= new_pages_num |
| 123 | + token_idxs[need_new_page_mask] = new_pages * page_size |
| 124 | + |
| 125 | + # 需要更新req_to_page_indexs |
| 126 | + new_page_req_indices = b_req_idx[need_new_page_mask] |
| 127 | + page_positions = (b_seq_len[need_new_page_mask] - 1) // page_size |
| 128 | + self.req_to_page_indexs[new_page_req_indices, page_positions] = new_pages |
| 129 | + |
| 130 | + mask = ~need_new_page_mask |
| 131 | + if mask.any(): |
| 132 | + seq_lens = b_seq_len[mask] |
| 133 | + token_idxs[mask] = ( |
| 134 | + self.req_to_token_indexs[b_req_idx[mask], seq_lens - 2] // page_size * page_size |
| 135 | + + (seq_lens - 1) % page_size |
| 136 | + ) |
| 137 | + return token_idxs |
| 138 | + |
| 139 | + def alloc(self, need_size, b_req_idx, b_seq_len, b_ready_cache_len=None, is_prefill=False) -> torch.Tensor: |
| 140 | + page_size = get_page_size() |
| 141 | + token_idxs = self.get_paged_token_indexs(b_req_idx, page_size, b_seq_len, b_ready_cache_len, is_prefill) |
| 142 | + self.can_use_mem_size -= need_size |
| 143 | + self.shared_can_use_token_num.set_value(self.can_use_mem_size) |
| 144 | + return token_idxs |
| 145 | + |
| 146 | + def free(self, free_index: Union[torch.Tensor, List[int]]): |
| 147 | + self.can_use_mem_size += len(free_index) |
| 148 | + self.shared_can_use_token_num.set_value(self.can_use_mem_size) |
| 149 | + |
| 150 | + page_size = get_page_size() |
| 151 | + if isinstance(free_index, list): |
| 152 | + free_index = torch.tensor(free_index, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True) |
| 153 | + |
| 154 | + if len(free_index) == 0: |
| 155 | + return |
| 156 | + |
| 157 | + page_indices = free_index // page_size |
| 158 | + unique_pages = torch.unique(page_indices) |
| 159 | + for page_idx in sorted(unique_pages, reverse=True): # 逆序放回,保持池的相对顺序 |
| 160 | + self.mark_page_start -= 1 |
| 161 | + self.page_idx_pool[self.mark_page_start] = page_idx |
| 162 | + self.can_use_page_size += 1 |
| 163 | + |
| 164 | + return |
| 165 | + |
| 166 | + def free_all(self): |
| 167 | + super().free_all() |
| 168 | + page_size = get_page_size() |
| 169 | + self.mark_page_start = 0 |
| 170 | + self.can_use_page_size = cdiv(self.size, page_size) |
| 171 | + self.page_idx_pool = torch.arange( |
| 172 | + 0, cdiv(self.size, page_size), dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True |
| 173 | + ) |
0 commit comments