@@ -19,53 +19,70 @@ import { softmax } from '../utils/maths.js';
1919 * @typedef {Object } FillMaskPipelineOptions Parameters specific to fill mask pipelines.
2020 * @property {number } [top_k=5] When passed, overrides the number of predictions to return.
2121 *
22- * @callback FillMaskPipelineCallback Fill the masked token in the text(s) given as inputs .
23- * @param {string|string[] } texts One or several texts (or one list of prompts) with masked tokens.
22+ * @callback FillMaskPipelineCallbackSingle Fill the masked token in the text given as input .
23+ * @param {string } texts The text with masked tokens.
2424 * @param {FillMaskPipelineOptions } [options] The options to use for masked language modelling.
25- * @returns {Promise<FillMaskOutput|FillMaskOutput[]> } An array of objects containing the score, predicted token, predicted token string,
26- * and the sequence with the predicted token filled in, or an array of such arrays (one for each input text).
27- * If only one input text is given, the output will be an array of objects.
28- * @throws {Error } When the mask token is not found in the input text.
25+ * @returns {Promise<FillMaskOutput> } An array of objects containing the score, predicted token, predicted token string,
26+ * and the sequence with the predicted token filled in.
27+ *
28+ * @callback FillMaskPipelineCallbackBatch Fill the masked token in the texts given as inputs.
29+ * @param {string[] } texts A list of texts with masked tokens.
30+ * @param {FillMaskPipelineOptions } [options] The options to use for masked language modelling.
31+ * @returns {Promise<FillMaskOutput[]> } An array where each entry corresponds to the predictions for an input text.
32+ *
33+ * @typedef {FillMaskPipelineCallbackSingle & FillMaskPipelineCallbackBatch } FillMaskPipelineCallback
2934 *
3035 * @typedef {TextPipelineConstructorArgs & FillMaskPipelineCallback & Disposable } FillMaskPipelineType
3136 */
3237
3338/**
3439 * Masked language modeling prediction pipeline using any `ModelWithLMHead`.
3540 *
41+ * **Example:** Perform masked language modelling (a.k.a. "fill-mask") with `onnx-community/ettin-encoder-32m-ONNX`.
42+ * ```javascript
43+ * import { pipeline } from '@huggingface/transformers';
44+ *
45+ * const unmasker = await pipeline('fill-mask', 'onnx-community/ettin-encoder-32m-ONNX');
46+ * const output = await unmasker('The capital of France is [MASK].');
47+ * // [
48+ * // { score: 0.5151872038841248, token: 7785, token_str: ' Paris', sequence: 'The capital of France is Paris.' },
49+ * // { score: 0.033725105226039886, token: 42268, token_str: ' Lyon', sequence: 'The capital of France is Lyon.' },
50+ * // { score: 0.031234024092555046, token: 23397, token_str: ' Nancy', sequence: 'The capital of France is Nancy.' },
51+ * // { score: 0.02075139433145523, token: 30167, token_str: ' Brussels', sequence: 'The capital of France is Brussels.' },
52+ * // { score: 0.018962178379297256, token: 31955, token_str: ' Geneva', sequence: 'The capital of France is Geneva.' }
53+ * // ]
54+ * ```
55+ *
3656 * **Example:** Perform masked language modelling (a.k.a. "fill-mask") with `Xenova/bert-base-uncased`.
3757 * ```javascript
58+ * import { pipeline } from '@huggingface/transformers';
59+ *
3860 * const unmasker = await pipeline('fill-mask', 'Xenova/bert-base-cased');
3961 * const output = await unmasker('The goal of life is [MASK].');
4062 * // [
41- * // { token_str: 'survival', score: 0.06137419492006302, token: 8115, sequence: ' The goal of life is survival.' },
42- * // { token_str: 'love', score: 0.03902450203895569, token: 1567, sequence: ' The goal of life is love.' },
43- * // { token_str: 'happiness', score: 0.03253183513879776, token: 9266, sequence: ' The goal of life is happiness.' },
44- * // { token_str: 'freedom', score: 0.018736306577920914, token: 4438, sequence: ' The goal of life is freedom.' },
45- * // { token_str: 'life', score: 0.01859794743359089, token: 1297, sequence: ' The goal of life is life.' }
63+ * // { score: 0.11368396878242493, sequence: " The goal of life is survival.", token: 8115, token_str: "survival" },
64+ * // { score: 0.053510840982198715, sequence: " The goal of life is love.", token: 1567, token_str: "love" },
65+ * // { score: 0.05041185021400452, sequence: " The goal of life is happiness.", token: 9266, token_str: "happiness" },
66+ * // { score: 0.033218126744031906, sequence: " The goal of life is freedom.", token: 4438, token_str: "freedom" },
67+ * // { score: 0.03301157429814339, sequence: " The goal of life is success.", token: 2244, token_str: "success" },
4668 * // ]
4769 * ```
4870 *
4971 * **Example:** Perform masked language modelling (a.k.a. "fill-mask") with `Xenova/bert-base-cased` (and return top result).
5072 * ```javascript
73+ * import { pipeline } from '@huggingface/transformers';
74+ *
5175 * const unmasker = await pipeline('fill-mask', 'Xenova/bert-base-cased');
5276 * const output = await unmasker('The Milky Way is a [MASK] galaxy.', { top_k: 1 });
53- * // [{ token_str: 'spiral', score: 0.6299987435340881, token: 14061, sequence: ' The Milky Way is a spiral galaxy.' }]
77+ * // [{ score: 0.5982972383499146, sequence: " The Milky Way is a spiral galaxy.", token: 14061, token_str: "spiral" }]
5478 * ```
5579 */
5680export class FillMaskPipeline
5781 extends /** @type {new (options: TextPipelineConstructorArgs) => FillMaskPipelineType } */ ( Pipeline )
5882{
59- /**
60- * Create a new FillMaskPipeline.
61- * @param {TextPipelineConstructorArgs } options An object used to instantiate the pipeline.
62- */
63- constructor ( options ) {
64- super ( options ) ;
65- }
66-
67- /** @type {FillMaskPipelineCallback } */
6883 async _call ( texts , { top_k = 5 } = { } ) {
84+ const { mask_token_id, mask_token } = this . tokenizer ;
85+
6986 // Run tokenization
7087 const model_inputs = this . tokenizer ( texts , {
7188 padding : true ,
@@ -84,11 +101,11 @@ export class FillMaskPipeline
84101 const mask_token_index = ids . findIndex (
85102 ( x ) =>
86103 // We use == to match bigint with number
87- // @ts -ignore
88- x == this . tokenizer . mask_token_id ,
104+ // @ts -expect-error TS2367
105+ x == mask_token_id ,
89106 ) ;
90107 if ( mask_token_index === - 1 ) {
91- throw Error ( `Mask token (${ this . tokenizer . mask_token } ) not found in text.` ) ;
108+ throw Error ( `Mask token (${ mask_token } ) not found in text.` ) ;
92109 }
93110 const itemLogits = logits [ i ] [ mask_token_index ] ;
94111
0 commit comments