Skip to content

Commit e60e9da

Browse files
committed
Update FillMaskPipeline types and default
1 parent 47f2d77 commit e60e9da

File tree

2 files changed

+46
-30
lines changed

2 files changed

+46
-30
lines changed

src/pipelines.js

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,8 @@ const SUPPORTED_TASKS = Object.freeze({
109109
pipeline: FillMaskPipeline,
110110
model: AutoModelForMaskedLM,
111111
default: {
112-
// TODO: replace with original
113-
// "model": "bert-base-uncased",
114-
model: 'Xenova/bert-base-uncased',
112+
model: 'onnx-community/ettin-encoder-32m-ONNX',
113+
dtype: 'fp32',
115114
},
116115
type: 'text',
117116
},
@@ -445,9 +444,9 @@ export async function pipeline(
445444
if (!model) {
446445
model = pipelineInfo.default.model;
447446
console.log(`No model specified. Using default model: "${model}".`);
448-
}
449-
if (!dtype && pipelineInfo.default.dtype) {
450-
dtype = pipelineInfo.default.dtype;
447+
if (!dtype && pipelineInfo.default.dtype) {
448+
dtype = pipelineInfo.default.dtype;
449+
}
451450
}
452451

453452
const pretrainedOptions = {

src/pipelines/fill-mask.js

Lines changed: 41 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -19,53 +19,70 @@ import { softmax } from '../utils/maths.js';
1919
* @typedef {Object} FillMaskPipelineOptions Parameters specific to fill mask pipelines.
2020
* @property {number} [top_k=5] When passed, overrides the number of predictions to return.
2121
*
22-
* @callback FillMaskPipelineCallback Fill the masked token in the text(s) given as inputs.
23-
* @param {string|string[]} texts One or several texts (or one list of prompts) with masked tokens.
22+
* @callback FillMaskPipelineCallbackSingle Fill the masked token in the text given as input.
23+
* @param {string} texts The text with masked tokens.
2424
* @param {FillMaskPipelineOptions} [options] The options to use for masked language modelling.
25-
* @returns {Promise<FillMaskOutput|FillMaskOutput[]>} An array of objects containing the score, predicted token, predicted token string,
26-
* and the sequence with the predicted token filled in, or an array of such arrays (one for each input text).
27-
* If only one input text is given, the output will be an array of objects.
28-
* @throws {Error} When the mask token is not found in the input text.
25+
* @returns {Promise<FillMaskOutput>} An array of objects containing the score, predicted token, predicted token string,
26+
* and the sequence with the predicted token filled in.
27+
*
28+
* @callback FillMaskPipelineCallbackBatch Fill the masked token in the texts given as inputs.
29+
* @param {string[]} texts A list of texts with masked tokens.
30+
* @param {FillMaskPipelineOptions} [options] The options to use for masked language modelling.
31+
* @returns {Promise<FillMaskOutput[]>} An array where each entry corresponds to the predictions for an input text.
32+
*
33+
* @typedef {FillMaskPipelineCallbackSingle & FillMaskPipelineCallbackBatch} FillMaskPipelineCallback
2934
*
3035
* @typedef {TextPipelineConstructorArgs & FillMaskPipelineCallback & Disposable} FillMaskPipelineType
3136
*/
3237

3338
/**
3439
* Masked language modeling prediction pipeline using any `ModelWithLMHead`.
3540
*
41+
* **Example:** Perform masked language modelling (a.k.a. "fill-mask") with `onnx-community/ettin-encoder-32m-ONNX`.
42+
* ```javascript
43+
* import { pipeline } from '@huggingface/transformers';
44+
*
45+
* const unmasker = await pipeline('fill-mask', 'onnx-community/ettin-encoder-32m-ONNX');
46+
* const output = await unmasker('The capital of France is [MASK].');
47+
* // [
48+
* // { score: 0.5151872038841248, token: 7785, token_str: ' Paris', sequence: 'The capital of France is Paris.' },
49+
* // { score: 0.033725105226039886, token: 42268, token_str: ' Lyon', sequence: 'The capital of France is Lyon.' },
50+
* // { score: 0.031234024092555046, token: 23397, token_str: ' Nancy', sequence: 'The capital of France is Nancy.' },
51+
* // { score: 0.02075139433145523, token: 30167, token_str: ' Brussels', sequence: 'The capital of France is Brussels.' },
52+
* // { score: 0.018962178379297256, token: 31955, token_str: ' Geneva', sequence: 'The capital of France is Geneva.' }
53+
* // ]
54+
* ```
55+
*
3656
* **Example:** Perform masked language modelling (a.k.a. "fill-mask") with `Xenova/bert-base-cased`.
3757
* ```javascript
58+
* import { pipeline } from '@huggingface/transformers';
59+
*
3860
* const unmasker = await pipeline('fill-mask', 'Xenova/bert-base-cased');
3961
* const output = await unmasker('The goal of life is [MASK].');
4062
* // [
41-
* // { token_str: 'survival', score: 0.06137419492006302, token: 8115, sequence: 'The goal of life is survival.' },
42-
* // { token_str: 'love', score: 0.03902450203895569, token: 1567, sequence: 'The goal of life is love.' },
43-
* // { token_str: 'happiness', score: 0.03253183513879776, token: 9266, sequence: 'The goal of life is happiness.' },
44-
* // { token_str: 'freedom', score: 0.018736306577920914, token: 4438, sequence: 'The goal of life is freedom.' },
45-
* // { token_str: 'life', score: 0.01859794743359089, token: 1297, sequence: 'The goal of life is life.' }
63+
* // { score: 0.11368396878242493, sequence: "The goal of life is survival.", token: 8115, token_str: "survival" },
64+
* // { score: 0.053510840982198715, sequence: "The goal of life is love.", token: 1567, token_str: "love" },
65+
* // { score: 0.05041185021400452, sequence: "The goal of life is happiness.", token: 9266, token_str: "happiness" },
66+
* // { score: 0.033218126744031906, sequence: "The goal of life is freedom.", token: 4438, token_str: "freedom" },
67+
* // { score: 0.03301157429814339, sequence: "The goal of life is success.", token: 2244, token_str: "success" },
4668
* // ]
4769
* ```
4870
*
4971
* **Example:** Perform masked language modelling (a.k.a. "fill-mask") with `Xenova/bert-base-cased` (and return top result).
5072
* ```javascript
73+
* import { pipeline } from '@huggingface/transformers';
74+
*
5175
* const unmasker = await pipeline('fill-mask', 'Xenova/bert-base-cased');
5276
* const output = await unmasker('The Milky Way is a [MASK] galaxy.', { top_k: 1 });
53-
* // [{ token_str: 'spiral', score: 0.6299987435340881, token: 14061, sequence: 'The Milky Way is a spiral galaxy.' }]
77+
* // [{ score: 0.5982972383499146, sequence: "The Milky Way is a spiral galaxy.", token: 14061, token_str: "spiral" }]
5478
* ```
5579
*/
5680
export class FillMaskPipeline
5781
extends /** @type {new (options: TextPipelineConstructorArgs) => FillMaskPipelineType} */ (Pipeline)
5882
{
59-
/**
60-
* Create a new FillMaskPipeline.
61-
* @param {TextPipelineConstructorArgs} options An object used to instantiate the pipeline.
62-
*/
63-
constructor(options) {
64-
super(options);
65-
}
66-
67-
/** @type {FillMaskPipelineCallback} */
6883
async _call(texts, { top_k = 5 } = {}) {
84+
const { mask_token_id, mask_token } = this.tokenizer;
85+
6986
// Run tokenization
7087
const model_inputs = this.tokenizer(texts, {
7188
padding: true,
@@ -84,11 +101,11 @@ export class FillMaskPipeline
84101
const mask_token_index = ids.findIndex(
85102
(x) =>
86103
// We use == to match bigint with number
87-
// @ts-ignore
88-
x == this.tokenizer.mask_token_id,
104+
// @ts-expect-error TS2367
105+
x == mask_token_id,
89106
);
90107
if (mask_token_index === -1) {
91-
throw Error(`Mask token (${this.tokenizer.mask_token}) not found in text.`);
108+
throw Error(`Mask token (${mask_token}) not found in text.`);
92109
}
93110
const itemLogits = logits[i][mask_token_index];
94111

0 commit comments

Comments
 (0)