@@ -162,9 +162,12 @@ def forward(self, audio_values, audio_lens_after_cnn):
162162 return x
163163
164164 def encode (self , audio_items : List [AudioItem ]):
165+ # Each element is one chunk (or the whole audio when it fits within max_length)
165166 batch_audios = []
166- batch_audio_lens = np . zeros ( len ( audio_items ), dtype = np . int32 )
167+ batch_audio_lens = []
167168 uuids = []
169+ # Records, for each chunk, the index in audio_items of the audio it belongs to
170+ chunk_owner_index = []
168171 for i , item in enumerate (audio_items ):
169172 if isinstance (item , AudioItem ):
170173 uuids .append (item .uuid )
@@ -180,8 +183,25 @@ def encode(self, audio_items: List[AudioItem]):
180183 if audio .shape [0 ] < MIN_AUDIO_LEN :
181184 audio = np .pad (audio , (0 , MIN_AUDIO_LEN - len (audio )), mode = "constant" , constant_values = 0.0 )
182185
183- batch_audio_lens [i ] = min (audio .shape [0 ], self .max_length )
184- batch_audios .append (audio )
186+ if audio .shape [0 ] > self .max_length :
187+ start = 0
188+ while start < audio .shape [0 ]:
189+ end = min (start + self .max_length , audio .shape [0 ])
190+ chunk = audio [start :end ]
191+
192+ if chunk .shape [0 ] < MIN_AUDIO_LEN :
193+ chunk = np .pad (chunk , (0 , MIN_AUDIO_LEN - chunk .shape [0 ]), mode = "constant" , constant_values = 0.0 )
194+ batch_audios .append (chunk )
195+ batch_audio_lens .append (min (chunk .shape [0 ], self .max_length ))
196+ chunk_owner_index .append (i )
197+
198+ start = end
199+ else :
200+ batch_audio_lens .append (min (audio .shape [0 ], self .max_length ))
201+ batch_audios .append (audio )
202+ chunk_owner_index .append (i )
203+
204+ batch_audio_lens = np .array (batch_audio_lens , dtype = np .int32 )
185205
186206 audios , audio_lens_after_cnn = self .audio_processor (
187207 batch_audios , batch_audio_lens , sampling_rate = 16000 , return_tensors = "pt"
@@ -190,13 +210,28 @@ def encode(self, audio_items: List[AudioItem]):
190210 audio_lens_after_cnn = np .array (audio_lens_after_cnn , dtype = np .int32 )
191211 audio_token_num = (audio_lens_after_cnn - 2 ) // 2 + 1
192212
213+ num_audios = len (audio_items )
214+ per_audio_embeds = [[] for _ in range (num_audios )]
215+
216+ for chunk_idx , owner in enumerate (chunk_owner_index ):
217+ token_len = int (audio_token_num [chunk_idx ])
218+ if token_len <= 0 :
219+ continue
220+ per_audio_embeds [owner ].append (audios [chunk_idx ][:token_len ])
221+
193222 ready_audio = obtain (self .cache_client .root .get_items_embed (uuids ))
194223 ids_to_set = []
195224 for i , ready in enumerate (ready_audio ):
196- if not ready :
197- uid = uuids [i ]
198- cur_embed_bytes = tensor2bytes (audios [i ][: audio_token_num [i ]])
199- create_shm (get_shm_name_embed (uid ), cur_embed_bytes )
200- ids_to_set .append (uid )
225+ if ready :
226+ continue
227+
228+ uid = uuids [i ]
229+
230+ # Concatenate the embeddings of all chunks belonging to this audio
231+ cur_embed = torch .cat (per_audio_embeds [i ], dim = 0 )
232+ cur_embed_bytes = tensor2bytes (cur_embed )
233+ create_shm (get_shm_name_embed (uid ), cur_embed_bytes )
234+ ids_to_set .append (uid )
235+
201236 if ids_to_set :
202237 self .cache_client .root .set_items_embed (ids = ids_to_set )
0 commit comments