Commit a4ff3c6

support-whisper-longaudio (#1128)
1 parent b2649ed commit a4ff3c6

File tree

2 files changed: +76, -18 lines changed


lightllm/models/internvl/model.py

Lines changed: 33 additions & 10 deletions
@@ -19,6 +19,7 @@
 )
 from lightllm.models.internvl.layer_weights.pre_and_post_layer_weight import InternVLInternlm2PreAndPostLayerWeight
 from lightllm.models.vit import get_image_patch_func
+from lightllm.models.whisper.defaults import MIN_AUDIO_LEN

 IMG_START_TOKEN = "<img>"
 IMG_END_TOKEN = "</img>"
@@ -47,6 +48,9 @@ def __init__(self, tokenizer, model_cfg, **kwargs):
         self.audio_end_id = tokenizer.convert_tokens_to_ids(self.audio_end_tag)
         self.get_image_patch_func = get_image_patch_func(kwargs["weight_dir"])

+        self.audio_min_length = MIN_AUDIO_LEN
+        self.audio_max_length = 16000 * 30
+
     def init_imageitem_extral_params(
         self, img: ImageItem, multi_params: MultimodalParams, sampling_params: SamplingParams
     ):
@@ -81,16 +85,35 @@ def get_image_token_length(self, img: ImageItem):

     def get_audio_token_length(self, audio: AudioItem):
         L = audio.audio_length
-        L = L if L <= 480000 else 480000  # max_length < 30s
-        mel_len = L // 160
-        dilation = 1
-        L_in = mel_len
-        for (padding, kernel_size, stride) in eval("[(1,3,1)] + [(1,3,2)] "):
-            L_out = L_in + 2 * padding - dilation * (kernel_size - 1) - 1
-            L_out = 1 + L_out // stride
-            L_in = L_out
-        audio_len_after_cnn = L_out
-        audio_token_num = (audio_len_after_cnn - 2) // 2 + 1
+        audio_token_num = 0
+        chunk_lens = []
+        if L <= self.audio_max_length:
+            cur_len = L
+            if cur_len < self.audio_min_length:
+                cur_len = self.audio_min_length
+            chunk_lens.append(cur_len)
+        else:
+            start = 0
+            while start < L:
+                end = min(start + self.audio_max_length, L)
+                cur_len = end - start
+
+                if cur_len < self.audio_min_length:
+                    cur_len = self.audio_min_length
+
+                chunk_lens.append(cur_len)
+                start = end
+        for chunk_len in chunk_lens:
+            mel_len = chunk_len // 160
+            dilation = 1
+            L_in = mel_len
+            for (padding, kernel_size, stride) in eval("[(1,3,1)] + [(1,3,2)] "):
+                L_out = L_in + 2 * padding - dilation * (kernel_size - 1) - 1
+                L_out = 1 + L_out // stride
+                L_in = L_out
+            audio_len_after_cnn = L_out
+            chunk_token_num = (audio_len_after_cnn - 2) // 2 + 1
+            audio_token_num += int(chunk_token_num)
         return audio_token_num

     # only change the impl of the encode func:
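For reference, a minimal standalone sketch of the token-count math in the hunk above. The helper name audio_token_length is hypothetical; min_len stands in for MIN_AUDIO_LEN, whose concrete value lives in lightllm.models.whisper.defaults and is not shown in this diff, so the 480 in the example call is illustrative only.

def audio_token_length(num_samples: int, min_len: int, max_len: int = 16000 * 30) -> int:
    # Split the waveform into chunks of at most 30 s (480000 samples at 16 kHz);
    # any chunk shorter than min_len is counted as if padded up to min_len.
    chunk_lens = []
    start = 0
    while start < num_samples or not chunk_lens:
        end = min(start + max_len, num_samples)
        chunk_lens.append(max(end - start, min_len))
        start = end

    total = 0
    for chunk_len in chunk_lens:
        # 160-sample hop -> mel frames, then the two front-end convolutions:
        # (padding=1, kernel=3, stride=1) and (padding=1, kernel=3, stride=2).
        l_in = chunk_len // 160
        for padding, kernel, stride in [(1, 3, 1), (1, 3, 2)]:
            l_in = 1 + (l_in + 2 * padding - (kernel - 1) - 1) // stride
        # Every 2 post-CNN frames map to 1 token.
        total += (l_in - 2) // 2 + 1
    return total

# A 75 s clip at 16 kHz splits into 30 s + 30 s + 15 s chunks:
# 750 + 750 + 375 = 1875 tokens.
print(audio_token_length(75 * 16000, min_len=480))  # min_len value illustrative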

lightllm/models/whisper/whisper_audio.py

Lines changed: 43 additions & 8 deletions
@@ -162,9 +162,12 @@ def forward(self, audio_values, audio_lens_after_cnn):
         return x

     def encode(self, audio_items: List[AudioItem]):
+        # each element is one chunk
         batch_audios = []
-        batch_audio_lens = np.zeros(len(audio_items), dtype=np.int32)
+        batch_audio_lens = []
         uuids = []
+        # record which audio_items index each chunk belongs to
+        chunk_owner_index = []
         for i, item in enumerate(audio_items):
             if isinstance(item, AudioItem):
                 uuids.append(item.uuid)
@@ -180,8 +183,25 @@ def encode(self, audio_items: List[AudioItem]):
             if audio.shape[0] < MIN_AUDIO_LEN:
                 audio = np.pad(audio, (0, MIN_AUDIO_LEN - len(audio)), mode="constant", constant_values=0.0)

-            batch_audio_lens[i] = min(audio.shape[0], self.max_length)
-            batch_audios.append(audio)
+            if audio.shape[0] > self.max_length:
+                start = 0
+                while start < audio.shape[0]:
+                    end = min(start + self.max_length, audio.shape[0])
+                    chunk = audio[start:end]
+
+                    if chunk.shape[0] < MIN_AUDIO_LEN:
+                        chunk = np.pad(chunk, (0, MIN_AUDIO_LEN - chunk.shape[0]), mode="constant", constant_values=0.0)
+                    batch_audios.append(chunk)
+                    batch_audio_lens.append(min(chunk.shape[0], self.max_length))
+                    chunk_owner_index.append(i)
+
+                    start = end
+            else:
+                batch_audio_lens.append(min(audio.shape[0], self.max_length))
+                batch_audios.append(audio)
+                chunk_owner_index.append(i)
+
+        batch_audio_lens = np.array(batch_audio_lens, dtype=np.int32)

         audios, audio_lens_after_cnn = self.audio_processor(
             batch_audios, batch_audio_lens, sampling_rate=16000, return_tensors="pt"
@@ -190,13 +210,28 @@ def encode(self, audio_items: List[AudioItem]):
         audio_lens_after_cnn = np.array(audio_lens_after_cnn, dtype=np.int32)
         audio_token_num = (audio_lens_after_cnn - 2) // 2 + 1

+        num_audios = len(audio_items)
+        per_audio_embeds = [[] for _ in range(num_audios)]
+
+        for chunk_idx, owner in enumerate(chunk_owner_index):
+            token_len = int(audio_token_num[chunk_idx])
+            if token_len <= 0:
+                continue
+            per_audio_embeds[owner].append(audios[chunk_idx][:token_len])
+
         ready_audio = obtain(self.cache_client.root.get_items_embed(uuids))
         ids_to_set = []
         for i, ready in enumerate(ready_audio):
-            if not ready:
-                uid = uuids[i]
-                cur_embed_bytes = tensor2bytes(audios[i][: audio_token_num[i]])
-                create_shm(get_shm_name_embed(uid), cur_embed_bytes)
-                ids_to_set.append(uid)
+            if ready:
+                continue
+
+            uid = uuids[i]
+
+            # concatenate all chunk embeddings for this audio
+            cur_embed = torch.cat(per_audio_embeds[i], dim=0)
+            cur_embed_bytes = tensor2bytes(cur_embed)
+            create_shm(get_shm_name_embed(uid), cur_embed_bytes)
+            ids_to_set.append(uid)
+
         if ids_to_set:
             self.cache_client.root.set_items_embed(ids=ids_to_set)
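A toy illustration, not part of the commit, of how chunk_owner_index regroups per-chunk embeddings back into one embedding per input audio. The shapes are hypothetical stand-ins (tokens x hidden) for a 30 s + 15 s split clip and a single 30 s clip.

import torch

# Chunks 0 and 1 come from audio 0 (a long clip split in two); chunk 2 is audio 1.
chunk_embeds = [torch.randn(750, 1280), torch.randn(375, 1280), torch.randn(750, 1280)]
chunk_owner_index = [0, 0, 1]

per_audio_embeds = [[] for _ in range(2)]
for chunk_idx, owner in enumerate(chunk_owner_index):
    per_audio_embeds[owner].append(chunk_embeds[chunk_idx])

# Concatenating along the token axis yields one embedding per original audio.
merged = [torch.cat(chunks, dim=0) for chunks in per_audio_embeds]
assert merged[0].shape == (750 + 375, 1280)
assert merged[1].shape == (750, 1280)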
