Skip to content

Commit 497441f

Browse files
committed
Update TextToAudioPipeline types and default
1 parent 11a6bfc commit 497441f

File tree

2 files changed

+14
-41
lines changed

2 files changed

+14
-41
lines changed

src/pipelines.js

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -210,9 +210,8 @@ const SUPPORTED_TASKS = Object.freeze({
210210
model: [AutoModelForTextToWaveform, AutoModelForTextToSpectrogram],
211211
processor: [AutoProcessor, /* Some don't use a processor */ null],
212212
default: {
213-
// TODO: replace with original
214-
// "model": "microsoft/speecht5_tts",
215-
model: 'Xenova/speecht5_tts',
213+
model: 'onnx-community/Supertonic-TTS-ONNX',
214+
dtype: 'fp32',
216215
},
217216
type: 'text',
218217
},

src/pipelines/text-to-audio.js

Lines changed: 12 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,7 @@ import { AutoModel } from '../models.js';
1717
*/
1818

1919
/**
20-
* @typedef {Object} TextToAudioOutput
21-
* @property {Float32Array} audio The generated audio waveform.
22-
* @property {number} sampling_rate The sampling rate of the generated audio waveform.
20+
* @typedef {RawAudio} TextToAudioOutput
2321
*
2422
* @typedef {Object} TextToAudioPipelineOptions Parameters specific to text-to-audio pipelines.
2523
* @property {Tensor|Float32Array|string|URL} [speaker_embeddings=null] The speaker embeddings (if the model requires it).
@@ -50,7 +48,7 @@ import { AutoModel } from '../models.js';
5048
* // audio: Float32Array(95232) [-0.000482565927086398, -0.0004853440332226455, ...],
5149
* // sampling_rate: 44100
5250
* // }
53-
*
51+
*
5452
* // Optional: Save the audio to a .wav file or Blob
5553
* await output.save('output.wav'); // You can also use `output.toBlob()` to access the audio as a Blob
5654
* ```
@@ -87,37 +85,24 @@ export class TextToAudioPipeline
8785
// Load speaker embeddings as Float32Array from path/URL
8886
if (typeof speaker_embeddings === 'string' || speaker_embeddings instanceof URL) {
8987
// Load from URL with fetch
90-
speaker_embeddings = new Float32Array(
91-
await (await fetch(speaker_embeddings)).arrayBuffer()
92-
);
88+
speaker_embeddings = new Float32Array(await (await fetch(speaker_embeddings)).arrayBuffer());
9389
}
9490

9591
if (speaker_embeddings instanceof Float32Array) {
96-
speaker_embeddings = new Tensor(
97-
'float32',
98-
speaker_embeddings,
99-
[speaker_embeddings.length]
100-
)
92+
speaker_embeddings = new Tensor('float32', speaker_embeddings, [speaker_embeddings.length]);
10193
} else if (!(speaker_embeddings instanceof Tensor)) {
102-
throw new Error("Speaker embeddings must be a `Tensor`, `Float32Array`, `string`, or `URL`.")
94+
throw new Error('Speaker embeddings must be a `Tensor`, `Float32Array`, `string`, or `URL`.');
10395
}
10496

10597
return speaker_embeddings;
10698
}
10799

108100
/** @type {TextToAudioPipelineCallback} */
109-
async _call(text_inputs, {
110-
speaker_embeddings = null,
111-
num_inference_steps,
112-
speed,
113-
} = {}) {
114-
101+
async _call(text_inputs, { speaker_embeddings = null, num_inference_steps, speed } = {}) {
115102
// If this.processor is not set, we are using a `AutoModelForTextToWaveform` model
116103
if (this.processor) {
117104
return this._call_text_to_spectrogram(text_inputs, { speaker_embeddings });
118-
} else if (
119-
this.model.config.model_type === "supertonic"
120-
) {
105+
} else if (this.model.config.model_type === 'supertonic') {
121106
return this._call_supertonic(text_inputs, { speaker_embeddings, num_inference_steps, speed });
122107
} else {
123108
return this._call_text_to_waveform(text_inputs);
@@ -126,14 +111,14 @@ export class TextToAudioPipeline
126111

127112
async _call_supertonic(text_inputs, { speaker_embeddings, num_inference_steps, speed }) {
128113
if (!speaker_embeddings) {
129-
throw new Error("Speaker embeddings must be provided for Supertonic models.");
114+
throw new Error('Speaker embeddings must be provided for Supertonic models.');
130115
}
131116
speaker_embeddings = await this._prepare_speaker_embeddings(speaker_embeddings);
132117

133118
// @ts-expect-error TS2339
134119
const { sampling_rate, style_dim } = this.model.config;
135120

136-
speaker_embeddings = (/** @type {Tensor} */ (speaker_embeddings)).view(1, -1, style_dim);
121+
speaker_embeddings = /** @type {Tensor} */ (speaker_embeddings).view(1, -1, style_dim);
137122
const inputs = this.tokenizer(text_inputs, {
138123
padding: true,
139124
truncation: true,
@@ -147,14 +132,10 @@ export class TextToAudioPipeline
147132
speed,
148133
});
149134

150-
return new RawAudio(
151-
waveform.data,
152-
sampling_rate,
153-
)
135+
return new RawAudio(waveform.data, sampling_rate);
154136
}
155137

156138
async _call_text_to_waveform(text_inputs) {
157-
158139
// Run tokenization
159140
const inputs = this.tokenizer(text_inputs, {
160141
padding: true,
@@ -166,14 +147,10 @@ export class TextToAudioPipeline
166147

167148
// @ts-expect-error TS2339
168149
const sampling_rate = this.model.config.sampling_rate;
169-
return new RawAudio(
170-
waveform.data,
171-
sampling_rate,
172-
)
150+
return new RawAudio(waveform.data, sampling_rate);
173151
}
174152

175153
async _call_text_to_spectrogram(text_inputs, { speaker_embeddings }) {
176-
177154
// Load vocoder, if not provided
178155
if (!this.vocoder) {
179156
console.log('No vocoder specified, using default HifiGan vocoder.');
@@ -193,9 +170,6 @@ export class TextToAudioPipeline
193170
const { waveform } = await this.model.generate_speech(input_ids, speaker_embeddings, { vocoder: this.vocoder });
194171

195172
const sampling_rate = this.processor.feature_extractor.config.sampling_rate;
196-
return new RawAudio(
197-
waveform.data,
198-
sampling_rate,
199-
)
173+
return new RawAudio(waveform.data, sampling_rate);
200174
}
201175
}

0 commit comments

Comments (0)