diff --git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py
index cc443bed4..542c6a485 100644
--- a/cosyvoice/cli/cosyvoice.py
+++ b/cosyvoice/cli/cosyvoice.py
@@ -72,6 +72,7 @@ def add_zero_shot_spk(self, prompt_text, prompt_speech_16k, zero_shot_spk_id):
         model_input = self.frontend.frontend_zero_shot('', prompt_text, prompt_speech_16k, self.sample_rate, '')
         del model_input['text']
         del model_input['text_len']
+        model_input['embedding'] = model_input['llm_embedding']
         self.frontend.spk2info[zero_shot_spk_id] = model_input
         return True
 
diff --git a/cosyvoice/cli/frontend.py b/cosyvoice/cli/frontend.py
index f98b0d612..4c6d06d11 100644
--- a/cosyvoice/cli/frontend.py
+++ b/cosyvoice/cli/frontend.py
@@ -198,8 +198,15 @@ def frontend_instruct(self, tts_text, spk_id, instruct_text):
 
     def frontend_instruct2(self, tts_text, instruct_text, prompt_speech_16k, resample_rate, zero_shot_spk_id):
         model_input = self.frontend_zero_shot(tts_text, instruct_text + '<|endofprompt|>', prompt_speech_16k, resample_rate, zero_shot_spk_id)
-        del model_input['llm_prompt_speech_token']
-        del model_input['llm_prompt_speech_token_len']
+        if bool(zero_shot_spk_id):
+            prompt_text_token, prompt_text_token_len = self._extract_text_token(instruct_text + '<|endofprompt|>')
+            model_input['prompt_text'] = prompt_text_token
+            model_input['prompt_text_len'] = prompt_text_token_len
+        if 'llm_prompt_speech_token' in model_input.keys():
+            del model_input['llm_prompt_speech_token']
+        if 'llm_prompt_speech_token_len' in model_input.keys():
+            del model_input['llm_prompt_speech_token_len']
+
         return model_input
 
     def frontend_vc(self, source_speech_16k, prompt_speech_16k, resample_rate):