Commit 998e083

Authored by SangChengC, sangchengmeng, and hiworldwzj
[support] vit fa support cu_seqlens and max_seqlens (#953)
Co-authored-by: sangchengmeng <sangchengmeng@sensetime.com>
Co-authored-by: hiworldwzj <30762946+hiworldwzj@users.noreply.github.com>
Co-authored-by: wangzaijun <wzjhelloworld@qq.com>
1 parent 81c5f61 commit 998e083

File tree

6 files changed (+142 −207 lines)

lightllm/common/basemodel/layer_infer/cache_tensor_manager.py

Lines changed: 11 additions & 0 deletions
@@ -93,6 +93,10 @@ def __init__(self):
         self.cuda_graph_cur_batch_size = None
         self.is_cuda_graph = False
         self.managed_total_tensor_bytes = 0
+        # Flag added to guard against misuse that would leak GPU memory.
+        # When the caller has not made valid calls to cache_env_in and cache_env_out,
+        # any call to the alloc_tensor interface falls back to torch.empty allocation.
+        self.cache_env_ok = False
 
     def cache_env_in(
         self, is_cuda_graph: bool = False, cur_batch_size: int = 0, cuda_graph_max_batch_size: int = 0
@@ -107,6 +111,7 @@ def cache_env_in(
             assert self.inner_cuda_graph_manager.cuda_graph_max_batch_size == cuda_graph_max_batch_size
             self.cuda_graph_cur_batch_size = cur_batch_size
             assert cur_batch_size != 0
+        self.cache_env_ok = True
         return
 
     def cache_env_out(self):
@@ -115,6 +120,7 @@ def cache_env_out(self):
         self.free_shape_dtype_to_bufs.clear()
         self.calcu_shape_cache.clear()
         self.changed_ptr.clear()
+        self.cache_env_ok = False
         return
 
     def alloc_tensor(
@@ -129,6 +135,11 @@ def alloc_tensor(
         # shape type conversion
         if isinstance(shape, list):
             shape = torch.Size(shape)
+
+        # the cache manager is not being used within a valid cache window
+        if not self.cache_env_ok:
+            return torch.empty(shape, dtype=data_type, device=device, requires_grad=False)
+
         # under cuda graph, allocation is taken over by the cuda graph manager
         if self.is_cuda_graph:
             return self.inner_cuda_graph_manager.alloc_tensor_for_cuda_graph(
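
The practical effect of the new flag: g_cache_manager.alloc_tensor only serves tensors from the managed pool while a cache_env_in / cache_env_out window is open, and any call made outside that window now degrades to a plain torch.empty allocation instead of quietly growing the managed pool. A minimal sketch of the guard pattern, using a simplified hypothetical manager rather than the real CacheTensorManager:

import torch


class SketchCacheManager:
    """Illustrative stand-in for the real manager; names and pooling logic are simplified."""

    def __init__(self):
        self.cache_env_ok = False
        self._pool = {}  # (shape, dtype) -> list of reusable tensors

    def cache_env_in(self):
        self.cache_env_ok = True

    def cache_env_out(self):
        self._pool.clear()
        self.cache_env_ok = False

    def alloc_tensor(self, shape, dtype, device="cpu"):
        if not self.cache_env_ok:
            # Guard added by this commit: outside a cache window, fall back to a
            # normal allocation that the garbage collector can reclaim.
            return torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
        key = (tuple(shape), dtype)
        bufs = self._pool.setdefault(key, [])
        return bufs.pop() if bufs else torch.empty(shape, dtype=dtype, device=device)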

lightllm/models/qwen2_5_vl/qwen2_5_visual.py

Lines changed: 8 additions & 75 deletions
@@ -22,6 +22,8 @@
 from transformers.utils import TensorType
 from lightllm.server.multimodal_params import MultimodalParams, ImageItem
 from lightllm.models.qwen2_vl.qwen2_visual import PatchEmbed, VisionRotaryEmbedding
+from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd
+from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
 
 # adapted from
 # https://github.com/huggingface/transformers/blob/
@@ -123,7 +125,7 @@ def apply_rotary_pos_emb_vision(
     return q_embed, k_embed
 
 
-class Qwen2_5_VLVisionAttention(nn.Module):
+class Qwen2_5_VLVisionFlashAttention(nn.Module):
     def __init__(self, dim: int, num_heads: int = 16) -> None:
         super().__init__()
         self.num_heads = num_heads
@@ -148,94 +150,28 @@ def forward(
             cos, sin = position_embeddings
         q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
 
-        attention_mask = torch.full(
-            [1, seq_length, seq_length],
-            torch.finfo(q.dtype).min,
-            device=q.device,
-            dtype=q.dtype,
-        )
-        for i in range(1, len(cu_seqlens)):
-            attention_mask[
-                ...,
-                cu_seqlens[i - 1] : cu_seqlens[i],
-                cu_seqlens[i - 1] : cu_seqlens[i],
-            ] = 0
-
-        q = q.transpose(0, 1)
-        k = k.transpose(0, 1)
-        v = v.transpose(0, 1)
-        attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
-        attn_weights = attn_weights + attention_mask
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
-        attn_output = torch.matmul(attn_weights, v)
-        attn_output = attn_output.transpose(0, 1)
-        attn_output = attn_output.reshape(seq_length, -1)
-        attn_output = self.proj(attn_output)
-        return attn_output
-
-
-class Qwen2_5_VLVisionSdpaAttention(nn.Module):
-    def __init__(self, dim: int, num_heads: int = 16) -> None:
-        super().__init__()
-        self.num_heads = num_heads
-        self.qkv = nn.Linear(dim, dim * 3, bias=True)
-        self.proj = nn.Linear(dim, dim)
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        cu_seqlens: torch.Tensor,
-        rotary_pos_emb: Optional[torch.Tensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-    ) -> torch.Tensor:
-        seq_length = hidden_states.shape[0]
-        q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
-        if position_embeddings is None:
-            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
-            cos = emb.cos()
-            sin = emb.sin()
-        else:
-            cos, sin = position_embeddings
-        q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
-
-        attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool)
-        for i in range(1, len(cu_seqlens)):
-            attention_mask[
-                ...,
-                cu_seqlens[i - 1] : cu_seqlens[i],
-                cu_seqlens[i - 1] : cu_seqlens[i],
-            ] = True
-        q = q.transpose(0, 1)
-        k = k.transpose(0, 1)
-        v = v.transpose(0, 1)
-        attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
-        attn_output = attn_output.transpose(0, 1)
+        cu_seqlens = cu_seqlens.to(q.device, torch.int32)
+        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+        attn_output = g_cache_manager.alloc_tensor(q.shape, q.dtype, device=q.device)
+        flash_attention_fwd(q, k, v, attn_output, cu_seqlens, max_seqlen)
         attn_output = attn_output.reshape(seq_length, -1)
         attn_output = self.proj(attn_output)
         return attn_output
 
 
-QWEN2_5_VL_VISION_ATTENTION_CLASSES = {
-    "eager": Qwen2_5_VLVisionAttention,
-    # "flash_attention_2": Qwen2_5_VLVisionFlashAttention2,
-    "sdpa": Qwen2_5_VLVisionSdpaAttention,
-}
-
-
 class Qwen2_5_VLVisionBlock(nn.Module):
     def __init__(
         self,
         hidden_size,
         intermediate_size,
         num_heads,
         hidden_act,
-        attn_implementation: str = "eager",
     ) -> None:
         super().__init__()
         self.norm1 = Qwen2RMSNorm(hidden_size, eps=1e-6)
         self.norm2 = Qwen2RMSNorm(hidden_size, eps=1e-6)
 
-        self.attn = QWEN2_5_VL_VISION_ATTENTION_CLASSES[attn_implementation](hidden_size, num_heads=num_heads)
+        self.attn = Qwen2_5_VLVisionFlashAttention(hidden_size, num_heads=num_heads)
         self.mlp = Qwen2_5_VLMLP(
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
@@ -312,8 +248,6 @@ def __init__(
 
         self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size
 
-        self.attn_implementation = "eager"
-
         self.patch_embed = PatchEmbed(
             patch_size=self.patch_size,
             temporal_patch_size=self.temporal_patch_size,
@@ -331,7 +265,6 @@ def __init__(
                     self.intermediate_size,
                     self.num_heads,
                     self.hidden_act,
-                    self.attn_implementation,
                 )
                 for _ in range(self.depth)
             ]
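
Both deleted attention paths materialized a dense [1, seq_length, seq_length] block-diagonal mask from cu_seqlens; the Triton flash_attention_fwd kernel instead consumes cu_seqlens and max_seqlen directly. As a reference for that bookkeeping, a small illustration with made-up per-image patch counts (seq_lens is hypothetical; the max_seqlen expression is the one used in the diff):

import torch
import torch.nn.functional as F

seq_lens = torch.tensor([64, 128, 96])                        # patches per image (made-up)
cu_seqlens = F.pad(seq_lens.cumsum(dim=0), (1, 0)).to(torch.int32)
# tensor([  0,  64, 192, 288], dtype=torch.int32)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()  # 128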

lightllm/models/qwen2_vl/qwen2_visual.py

Lines changed: 13 additions & 48 deletions
@@ -43,7 +43,8 @@
 from transformers.utils import TensorType
 from lightllm.server.multimodal_params import MultimodalParams, ImageItem
 from lightllm.models.qwen2_vl.vision_process import Qwen2VLImageProcessor
-
+from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd
+from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
 
 from transformers.utils import is_flash_attn_2_available
 
@@ -210,7 +211,7 @@ def forward(
 
 # adapted from
 # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
-class VisionFlashAttention2(nn.Module):
+class VisionFlashAttention(nn.Module):
     def __init__(self, dim: int, num_heads: int = 16) -> None:
         super().__init__()
         self.num_heads = num_heads
@@ -222,63 +223,31 @@ def forward(
     ) -> torch.Tensor:
         seq_length = hidden_states.shape[0]
         q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
-        q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
-        k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
+        q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb)
+        k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb)
+        q = q.squeeze(0)
+        k = k.squeeze(0)
 
+        cu_seqlens = cu_seqlens.to(q.device, torch.int32)
         max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
-        attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
-            seq_length, -1
-        )
-        attn_output = self.proj(attn_output)
-        return attn_output
-
-
-# adapted from
-# https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
-class VisionSdpaAttention(nn.Module):
-    def __init__(self, dim: int, num_heads: int = 16) -> None:
-        super().__init__()
-        self.num_heads = num_heads
-        self.qkv = nn.Linear(dim, dim * 3, bias=True)
-        self.proj = nn.Linear(dim, dim)
-
-    def forward(
-        self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
-    ) -> torch.Tensor:
-        seq_length = hidden_states.shape[0]
-        q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
-        q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
-        k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
+        attn_output = g_cache_manager.alloc_tensor(q.shape, q.dtype, device=q.device)
 
-        attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool)
-        for i in range(1, len(cu_seqlens)):
-            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
-        q = q.transpose(0, 1)
-        k = k.transpose(0, 1)
-        v = v.transpose(0, 1)
-        attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
-        attn_output = attn_output.transpose(0, 1)
+        flash_attention_fwd(q, k, v, attn_output, cu_seqlens, max_seqlen)
         attn_output = attn_output.reshape(seq_length, -1)
         attn_output = self.proj(attn_output)
         return attn_output
 
 
-QWEN2_VL_VISION_ATTENTION_CLASSES = {
-    "eager": VisionAttention,
-    # "flash_attention_2": VisionFlashAttention2,
-    "sdpa": VisionSdpaAttention,
-}
-
 # adapted from
 # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
 class Qwen2VLVisionBlock(nn.Module):
-    def __init__(self, embed_dim, mlp_ratio, num_heads, hidden_act, attn_implementation: str = "eager") -> None:
+    def __init__(self, embed_dim, mlp_ratio, num_heads, hidden_act) -> None:
         super().__init__()
         self.norm1 = LayerNorm(embed_dim, eps=1e-6)
         self.norm2 = LayerNorm(embed_dim, eps=1e-6)
         mlp_hidden_dim = int(embed_dim * mlp_ratio)
 
-        self.attn = QWEN2_VL_VISION_ATTENTION_CLASSES[attn_implementation](embed_dim, num_heads=num_heads)
+        self.attn = VisionFlashAttention(embed_dim, num_heads=num_heads)
         self.mlp = VisionMlp(dim=embed_dim, hidden_dim=mlp_hidden_dim, hidden_act=hidden_act)
 
     def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor:
@@ -318,8 +287,6 @@ def __init__(
         self.spatial_merge_size = spatial_merge_size
         self.temporal_patch_size = temporal_patch_size
 
-        self.attn_implementation = "eager"
-
         self.patch_embed = PatchEmbed(
             patch_size=self.patch_size,
             temporal_patch_size=self.temporal_patch_size,
@@ -332,9 +299,7 @@ def __init__(
 
         self.blocks = nn.ModuleList(
             [
-                Qwen2VLVisionBlock(
-                    self.embed_dim, self.mlp_ratio, self.num_heads, self.hidden_act, self.attn_implementation
-                )
+                Qwen2VLVisionBlock(self.embed_dim, self.mlp_ratio, self.num_heads, self.hidden_act)
                 for _ in range(self.depth)
             ]
         )
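
The q, k, v tensors handed to flash_attention_fwd keep the (seq_length, num_heads, head_dim) layout produced by the qkv projection, and attn_output is pre-allocated with the same shape so the kernel can write into it in place. A quick shape check of that layout with illustrative sizes (a plain nn.Linear stands in for the module's qkv projection):

import torch
import torch.nn as nn

seq_length, num_heads, head_dim = 288, 16, 80   # made-up sizes
dim = num_heads * head_dim
qkv = nn.Linear(dim, dim * 3, bias=True)

hidden_states = torch.randn(seq_length, dim)
q, k, v = qkv(hidden_states).reshape(seq_length, 3, num_heads, -1).permute(1, 0, 2, 3).unbind(0)
assert q.shape == k.shape == v.shape == (seq_length, num_heads, head_dim)

# Output buffer with the same layout; in the model it comes from g_cache_manager.alloc_tensor.
attn_output = torch.empty_like(q)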

lightllm/models/vit/layer_infer/transformer_layer_infer.py

Lines changed: 8 additions & 9 deletions
@@ -1,13 +1,8 @@
 import torch
-import torch.functional as F
 import torch.distributed as dist
-import numpy as np
-from typing import Tuple
-from functools import partial
-import triton
+
 
 from lightllm.models.vit.layer_weights.transformer_layer_weight import ViTTransformerLayerWeight
-from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward, torch_rms_norm
 from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd
 from lightllm.utils.dist_utils import get_current_rank_in_dp, get_dp_world_size
 from lightllm.models.vit.triton_kernel.gelu_vit import gelu_fwd
@@ -103,9 +98,13 @@ def _get_qkv(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tens
 
     def _context_attention_kernel(self, q, k, v) -> torch.Tensor:
         out = g_cache_manager.alloc_tensor(q.shape, q.dtype, device=q.device)
-        batch_size = q.shape[0]
-        seq_len = q.shape[1]
-        flash_attention_fwd(q, k, v, out)
+        batch_size, seq_len, head_num, head_dim = q.shape
+        total_len = batch_size * seq_len
+        reshape = lambda t: t.view(total_len, head_num, head_dim)
+        q, k, v, out = map(reshape, (q, k, v, out))
+        cu_seqlens = torch.arange(batch_size + 1, dtype=torch.int32, device=q.device) * seq_len
+        max_seqlen = seq_len
+        flash_attention_fwd(q, k, v, out, cu_seqlens, max_seqlen)
         return out.reshape(batch_size, seq_len, -1)
 
     def _get_o(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor:
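
In the ViT layer every image contributes exactly seq_len patches, so the (batch_size, seq_len, head_num, head_dim) tensors are flattened to (batch_size * seq_len, head_num, head_dim) and cu_seqlens reduces to multiples of seq_len. A sketch of that bookkeeping with made-up sizes (shapes only, no kernel call):

import torch

batch_size, seq_len, head_num, head_dim = 4, 256, 16, 64   # made-up sizes
q = torch.randn(batch_size, seq_len, head_num, head_dim)

total_len = batch_size * seq_len
q_flat = q.view(total_len, head_num, head_dim)              # layout expected by flash_attention_fwd
cu_seqlens = torch.arange(batch_size + 1, dtype=torch.int32) * seq_len
# tensor([   0,  256,  512,  768, 1024], dtype=torch.int32)
max_seqlen = seq_len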
