diff --git a/src/backends/onnx.js b/src/backends/onnx.js index b5923a596..9ac28e8df 100644 --- a/src/backends/onnx.js +++ b/src/backends/onnx.js @@ -22,6 +22,7 @@ import { env, apis } from '../env.js'; // In either case, we select the default export if it exists, otherwise we use the named export. import * as ONNX_NODE from 'onnxruntime-node'; import * as ONNX_WEB from 'onnxruntime-web/webgpu'; +import { loadWasmBinary, loadWasmFactory } from './utils/cacheWasm.js'; export { Tensor } from 'onnxruntime-common'; @@ -141,6 +142,79 @@ const IS_WEB_ENV = apis.IS_BROWSER_ENV || apis.IS_WEBWORKER_ENV; */ let webInitChain = Promise.resolve(); +/** + * Promise that resolves when WASM binary has been loaded (if caching is enabled). + * This ensures we only attempt to load the WASM binary once. + * @type {Promise|null} + */ +let wasmLoadPromise = null; + +/** + * Ensures the WASM binary is loaded and cached before creating an inference session. + * Only runs once, even if called multiple times. + * + * @returns {Promise} + */ +async function ensureWasmLoaded() { + // If already loading or loaded, return the existing promise + if (wasmLoadPromise) { + return wasmLoadPromise; + } + + const shouldUseWasmCache = + env.useWasmCache && + typeof ONNX_ENV?.wasm?.wasmPaths === 'object' && + ONNX_ENV?.wasm?.wasmPaths?.wasm && + ONNX_ENV?.wasm?.wasmPaths?.mjs; + + // Check if we should load the WASM binary + if (!shouldUseWasmCache) { + wasmLoadPromise = Promise.resolve(); + return wasmLoadPromise; + } + + // Start loading the WASM binary + wasmLoadPromise = (async () => { + // At this point, we know wasmPaths is an object (not a string) because + // shouldUseWasmCache checks for wasmPaths.wasm and wasmPaths.mjs + const urls = /** @type {{ wasm: string, mjs: string }} */ (ONNX_ENV.wasm.wasmPaths); + + // Load and cache both the WASM binary and factory + await Promise.all([ + // Load and cache the WASM binary + urls.wasm + ? 
(async () => { + try { + const wasmBinary = await loadWasmBinary(urls.wasm); + if (wasmBinary) { + ONNX_ENV.wasm.wasmBinary = wasmBinary; + } + } catch (err) { + console.warn('Failed to pre-load WASM binary:', err); + } + })() + : Promise.resolve(), + + // Load and cache the WASM factory + urls.mjs + ? (async () => { + try { + const wasmFactoryBlob = await loadWasmFactory(urls.mjs); + if (wasmFactoryBlob) { + // @ts-ignore + ONNX_ENV.wasm.wasmPaths.mjs = wasmFactoryBlob; + } + } catch (err) { + console.warn('Failed to pre-load WASM factory:', err); + } + })() + : Promise.resolve(), + ]); + })(); + + return wasmLoadPromise; +} + /** * Create an ONNX inference session. * @param {Uint8Array|string} buffer_or_path The ONNX model buffer or path. @@ -149,6 +223,8 @@ let webInitChain = Promise.resolve(); * @returns {Promise} The ONNX inference session. */ export async function createInferenceSession(buffer_or_path, session_options, session_config) { + await ensureWasmLoaded(); + const load = () => InferenceSession.create(buffer_or_path, session_options); const session = await (IS_WEB_ENV ? (webInitChain = webInitChain.then(load)) : load()); session.config = session_config; @@ -201,15 +277,15 @@ if (ONNX_ENV?.wasm) { ONNX_ENV.wasm.wasmPaths = apis.IS_SAFARI ? { - mjs: `${wasmPathPrefix}/ort-wasm-simd-threaded.mjs`, - wasm: `${wasmPathPrefix}/ort-wasm-simd-threaded.wasm`, + mjs: `${wasmPathPrefix}ort-wasm-simd-threaded.mjs`, + wasm: `${wasmPathPrefix}ort-wasm-simd-threaded.wasm`, } - : wasmPathPrefix; + : { + mjs: `${wasmPathPrefix}ort-wasm-simd-threaded.asyncify.mjs`, + wasm: `${wasmPathPrefix}ort-wasm-simd-threaded.asyncify.wasm`, + }; } - // TODO: Add support for loading WASM files from cached buffer when we upgrade to onnxruntime-web@1.19.0 - // https://github.com/microsoft/onnxruntime/pull/21534 - // Users may wish to proxy the WASM backend to prevent the UI from freezing, // However, this is not necessary when using WebGPU, so we default to false. 
ONNX_ENV.wasm.proxy = false; diff --git a/src/backends/utils/cacheWasm.js b/src/backends/utils/cacheWasm.js new file mode 100644 index 000000000..0c8170596 --- /dev/null +++ b/src/backends/utils/cacheWasm.js @@ -0,0 +1,84 @@ +import { getCache } from '../../utils/cache.js'; + +/** + * Loads and caches a file from the given URL. + * @param {string} url The URL of the file to load. + * @returns {Promise} The response object, or null if loading failed. + */ +async function loadAndCacheFile(url) { + const fileName = url.split('/').pop(); + try { + const cache = await getCache(); + + // Try to get from cache first + if (cache) { + try { + const cached = await cache.match(url); + if (cached) return cached; + } catch (e) { + console.warn(`Error reading ${fileName} from cache:`, e); + } + } + + // If not in cache, fetch it + const response = await fetch(url); + + if (!response.ok) { + throw new Error(`Failed to fetch ${fileName}: ${response.status} ${response.statusText}`); + } + + // Cache the response for future use + if (cache) { + try { + await cache.put(url, response.clone()); + } catch (e) { + console.warn(`Failed to cache ${fileName}:`, e); + } + } + + return response; + } catch (error) { + console.warn(`Failed to load ${fileName}:`, error); + return null; + } +} + +/** + * Loads and caches the WASM binary for ONNX Runtime. + * @param {string} wasmURL The URL of the WASM file to load. + * @returns {Promise} The WASM binary as an ArrayBuffer, or null if loading failed. + */ + +export async function loadWasmBinary(wasmURL) { + const response = await loadAndCacheFile(wasmURL); + if (!response || typeof response === 'string') return null; + + try { + return await response.arrayBuffer(); + } catch (error) { + console.warn('Failed to read WASM binary:', error); + return null; + } +} + +/** + * Loads and caches the WASM Factory for ONNX Runtime. + * @param {string} libURL The URL of the WASM Factory to load. + * @returns {Promise} The blob URL of the WASM Factory, or null if loading failed. 
+ */ +export async function loadWasmFactory(libURL) { + const response = await loadAndCacheFile(libURL); + if (!response || typeof response === 'string') return null; + + try { + let code = await response.text(); + // Fix relative paths when loading factory from blob, overwrite import.meta.url with actual baseURL + const baseUrl = libURL.split('/').slice(0, -1).join('/'); + code = code.replace(/import\.meta\.url/g, `"${baseUrl}"`); + const blob = new Blob([code], { type: 'text/javascript' }); + return URL.createObjectURL(blob); + } catch (error) { + console.warn('Failed to read WASM factory:', error); + return null; + } +} diff --git a/src/env.js b/src/env.js index 7b8fc9d03..bf19628cf 100644 --- a/src/env.js +++ b/src/env.js @@ -152,9 +152,12 @@ const localModelPath = RUNNING_LOCALLY ? path.join(dirname__, DEFAULT_LOCAL_MODE * @property {boolean} useFSCache Whether to use the file system to cache files. By default, it is `true` if available. * @property {string|null} cacheDir The directory to use for caching files with the file system. By default, it is `./.cache`. * @property {boolean} useCustomCache Whether to use a custom cache system (defined by `customCache`), defaults to `false`. - * @property {Object|null} customCache The custom cache to use. Defaults to `null`. Note: this must be an object which + * @property {import('./utils/cache.js').CacheInterface|null} customCache The custom cache to use. Defaults to `null`. Note: this must be an object which * implements the `match` and `put` functions of the Web Cache API. For more information, see https://developer.mozilla.org/en-US/docs/Web/API/Cache. - * If you wish, you may also return a `Promise` from the `match` function if you'd like to use a file path instead of `Promise`. + * @property {boolean} useWasmCache Whether to pre-load and cache WASM binaries for ONNX Runtime. Defaults to `true` when cache is available. + * This can improve performance by avoiding repeated downloads of WASM files. 
Note: Both the WASM binary and the MJS factory loader are cached. + * The MJS factory is served from a blob URL created from the cached file, so no extra network access is required. + * @property {string} cacheKey The cache key to use for storing models and WASM binaries. Defaults to 'transformers-cache'. */ /** @type {TransformersEnvironment} */ @@ -185,6 +188,9 @@ export const env = { useCustomCache: false, customCache: null, + + useWasmCache: IS_WEB_CACHE_AVAILABLE || IS_FS_AVAILABLE, + cacheKey: 'transformers-cache', ////////////////////////////////////////////////////// }; diff --git a/src/utils/cache.js b/src/utils/cache.js new file mode 100644 index 000000000..c803546d5 --- /dev/null +++ b/src/utils/cache.js @@ -0,0 +1,82 @@ +import { apis, env } from '../env.js'; +import FileCache from './hub/FileCache.js'; + +/** + * @typedef {Object} CacheInterface + * @property {(request: string) => Promise} match + * Checks if a request is in the cache and returns the cached response if found. + * @property {(request: string, response: Response, progress_callback?: (data: {progress: number, loaded: number, total: number}) => void) => Promise} put + * Adds a response to the cache. + */ + +/** + * Retrieves an appropriate caching backend based on the environment configuration. + * Attempts to use custom cache, browser cache, or file system cache in that order of priority. + * @returns {Promise} + * @param {string|null} file_cache_dir Path to a directory in which a downloaded pretrained model configuration should be cached if using the file system cache. + */ +export async function getCache(file_cache_dir = null) { + // First, check if a caching backend is available + // If no caching mechanism available, will download the file every time + let cache = null; + if (env.useCustomCache) { + // Allow the user to specify a custom cache system. 
+ if (!env.customCache) { + throw Error('`env.useCustomCache=true`, but `env.customCache` is not defined.'); + } + + // Check that the required methods are defined: + if (!env.customCache.match || !env.customCache.put) { + throw new Error( + '`env.customCache` must be an object which implements the `match` and `put` functions of the Web Cache API. ' + + 'For more information, see https://developer.mozilla.org/en-US/docs/Web/API/Cache', + ); + } + cache = env.customCache; + } + + if (!cache && env.useBrowserCache) { + if (typeof caches === 'undefined') { + throw Error('Browser cache is not available in this environment.'); + } + try { + // In some cases, the browser cache may be visible, but not accessible due to security restrictions. + // For example, when running an application in an iframe, if a user attempts to load the page in + // incognito mode, the following error is thrown: `DOMException: Failed to execute 'open' on 'CacheStorage': + // An attempt was made to break through the security policy of the user agent.` + // So, instead of crashing, we just ignore the error and continue without using the cache. + cache = await caches.open(env.cacheKey); + } catch (e) { + console.warn('An error occurred while opening the browser cache:', e); + } + } + + if (!cache && env.useFSCache) { + if (!apis.IS_FS_AVAILABLE) { + throw Error('File System Cache is not available in this environment.'); + } + + // If `cache_dir` is not specified, use the default cache directory + cache = new FileCache(file_cache_dir ?? env.cacheDir); + } + + return cache; +} + +/** + * Searches the cache for any of the provided names and returns the first match found. + * @param {CacheInterface} cache The cache to search + * @param {...string} names The names of the items to search for + * @returns {Promise} The item from the cache, or undefined if not found. 
+ */ +export async function tryCache(cache, ...names) { + for (let name of names) { + try { + let result = await cache.match(name); + if (result) return result; + } catch (e) { + continue; + } + } + return undefined; +} diff --git a/src/utils/hub.js b/src/utils/hub.js index ca5967faf..07c4f492c 100755 --- a/src/utils/hub.js +++ b/src/utils/hub.js @@ -4,18 +4,22 @@ * @module utils/hub */ -import fs from 'node:fs'; -import path from 'node:path'; - import { apis, env } from '../env.js'; import { dispatchCallback } from './core.js'; +import FileResponse from './hub/FileResponse.js'; +import FileCache from './hub/FileCache.js'; +import { handleError, isValidUrl, pathJoin, isValidHfModelId, readResponse } from './hub/utils.js'; +import { getCache, tryCache } from './cache.js'; + +export { MAX_EXTERNAL_DATA_CHUNKS } from './hub/constants.js'; /** - * @typedef {boolean|number} ExternalData Whether to load the model using the external data format (used for models >= 2GB in size). - * If `true`, the model will be loaded using the external data format. - * If a number, this many chunks will be loaded using the external data format (of the form: "model.onnx_data[_{chunk_number}]"). + * @typedef {boolean|number} ExternalData + * Specifies whether to load the model using the external data format. + * - `false`: Do not use external data format + * - `true`: Use external data format with 1 chunk + * - `number`: Use external data format with the specified number of chunks */ -export const MAX_EXTERNAL_DATA_CHUNKS = 100; /** * @typedef {Object} PretrainedOptions Options for loading a pretrained model. @@ -45,165 +49,6 @@ export const MAX_EXTERNAL_DATA_CHUNKS = 100; * @typedef {PretrainedOptions & ModelSpecificPretrainedOptions} PretrainedModelOptions Options for loading a pretrained model. */ -/** - * Mapping from file extensions to MIME types. 
- */ -const CONTENT_TYPE_MAP = { - txt: 'text/plain', - html: 'text/html', - css: 'text/css', - js: 'text/javascript', - json: 'application/json', - png: 'image/png', - jpg: 'image/jpeg', - jpeg: 'image/jpeg', - gif: 'image/gif', -}; -class FileResponse { - /** - * Creates a new `FileResponse` object. - * @param {string} filePath - */ - constructor(filePath) { - this.filePath = filePath; - this.headers = new Headers(); - - this.exists = fs.existsSync(filePath); - if (this.exists) { - this.status = 200; - this.statusText = 'OK'; - - let stats = fs.statSync(filePath); - this.headers.set('content-length', stats.size.toString()); - - this.updateContentType(); - - const stream = fs.createReadStream(filePath); - this.body = new ReadableStream({ - start(controller) { - stream.on('data', (chunk) => controller.enqueue(chunk)); - stream.on('end', () => controller.close()); - stream.on('error', (err) => controller.error(err)); - }, - cancel() { - stream.destroy(); - }, - }); - } else { - this.status = 404; - this.statusText = 'Not Found'; - this.body = null; - } - } - - /** - * Updates the 'content-type' header property of the response based on the extension of - * the file specified by the filePath property of the current object. - * @returns {void} - */ - updateContentType() { - // Set content-type header based on file extension - const extension = this.filePath.toString().split('.').pop().toLowerCase(); - this.headers.set('content-type', CONTENT_TYPE_MAP[extension] ?? 'application/octet-stream'); - } - - /** - * Clone the current FileResponse object. - * @returns {FileResponse} A new FileResponse object with the same properties as the current object. 
- */ - clone() { - let response = new FileResponse(this.filePath); - response.exists = this.exists; - response.status = this.status; - response.statusText = this.statusText; - response.headers = new Headers(this.headers); - return response; - } - - /** - * Reads the contents of the file specified by the filePath property and returns a Promise that - * resolves with an ArrayBuffer containing the file's contents. - * @returns {Promise} A Promise that resolves with an ArrayBuffer containing the file's contents. - * @throws {Error} If the file cannot be read. - */ - async arrayBuffer() { - const data = await fs.promises.readFile(this.filePath); - return /** @type {ArrayBuffer} */ (data.buffer); - } - - /** - * Reads the contents of the file specified by the filePath property and returns a Promise that - * resolves with a Blob containing the file's contents. - * @returns {Promise} A Promise that resolves with a Blob containing the file's contents. - * @throws {Error} If the file cannot be read. - */ - async blob() { - const data = await fs.promises.readFile(this.filePath); - return new Blob([/** @type {any} */ (data)], { type: this.headers.get('content-type') }); - } - - /** - * Reads the contents of the file specified by the filePath property and returns a Promise that - * resolves with a string containing the file's contents. - * @returns {Promise} A Promise that resolves with a string containing the file's contents. - * @throws {Error} If the file cannot be read. - */ - async text() { - const data = await fs.promises.readFile(this.filePath, 'utf8'); - return data; - } - - /** - * Reads the contents of the file specified by the filePath property and returns a Promise that - * resolves with a parsed JavaScript object containing the file's contents. - * - * @returns {Promise} A Promise that resolves with a parsed JavaScript object containing the file's contents. - * @throws {Error} If the file cannot be read. 
- */ - async json() { - return JSON.parse(await this.text()); - } -} - -/** - * Determines whether the given string is a valid URL. - * @param {string|URL} string The string to test for validity as an URL. - * @param {string[]} [protocols=null] A list of valid protocols. If specified, the protocol must be in this list. - * @param {string[]} [validHosts=null] A list of valid hostnames. If specified, the URL's hostname must be in this list. - * @returns {boolean} True if the string is a valid URL, false otherwise. - */ -function isValidUrl(string, protocols = null, validHosts = null) { - let url; - try { - url = new URL(string); - } catch (_) { - return false; - } - if (protocols && !protocols.includes(url.protocol)) { - return false; - } - if (validHosts && !validHosts.includes(url.hostname)) { - return false; - } - return true; -} - -const REPO_ID_REGEX = /^(\b[\w\-.]+\b\/)?\b[\w\-.]{1,96}\b$/; - -/** - * Tests whether a string is a valid Hugging Face model ID or not. - * Adapted from https://github.com/huggingface/huggingface_hub/blob/6378820ebb03f071988a96c7f3268f5bdf8f9449/src/huggingface_hub/utils/_validators.py#L119-L170 - * - * @param {string} string The string to test - * @returns {boolean} True if the string is a valid model ID, false otherwise. - */ -function isValidHfModelId(string) { - if (!REPO_ID_REGEX.test(string)) return false; - if (string.includes('..') || string.includes('--')) return false; - if (string.endsWith('.git') || string.endsWith('.ipynb')) return false; - return true; -} - /** * Helper function to get a file, using either the Fetch API or FileSystem API. 
* @@ -246,142 +91,6 @@ export async function getFile(urlOrPath) { } } -const ERROR_MAPPING = { - // 4xx errors (https://developer.mozilla.org/en-US/docs/Web/HTTP/Status#client_error_responses) - 400: 'Bad request error occurred while trying to load file', - 401: 'Unauthorized access to file', - 403: 'Forbidden access to file', - 404: 'Could not locate file', - 408: 'Request timeout error occurred while trying to load file', - - // 5xx errors (https://developer.mozilla.org/en-US/docs/Web/HTTP/Status#server_error_responses) - 500: 'Internal server error error occurred while trying to load file', - 502: 'Bad gateway error occurred while trying to load file', - 503: 'Service unavailable error occurred while trying to load file', - 504: 'Gateway timeout error occurred while trying to load file', -}; -/** - * Helper method to handle fatal errors that occur while trying to load a file from the Hugging Face Hub. - * @param {number} status The HTTP status code of the error. - * @param {string} remoteURL The URL of the file that could not be loaded. - * @param {boolean} fatal Whether to raise an error if the file could not be loaded. - * @returns {null} Returns `null` if `fatal = true`. - * @throws {Error} If `fatal = false`. - */ -function handleError(status, remoteURL, fatal) { - if (!fatal) { - // File was not loaded correctly, but it is optional. - // TODO in future, cache the response? - return null; - } - - const message = ERROR_MAPPING[status] ?? `Error (${status}) occurred while trying to load file`; - throw Error(`${message}: "${remoteURL}".`); -} - -class FileCache { - /** - * Instantiate a `FileCache` object. - * @param {string} path - */ - constructor(path) { - this.path = path; - } - - /** - * Checks whether the given request is in the cache. 
- * @param {string} request - * @returns {Promise} - */ - async match(request) { - let filePath = path.join(this.path, request); - let file = new FileResponse(filePath); - - if (file.exists) { - return file; - } else { - return undefined; - } - } - - /** - * Adds the given response to the cache. - * @param {string} request - * @param {Response} response - * @param {(data: {progress: number, loaded: number, total: number}) => void} [progress_callback] Optional. - * The function to call with progress updates - * @returns {Promise} - */ - async put(request, response, progress_callback = undefined) { - let filePath = path.join(this.path, request); - - try { - const contentLength = response.headers.get('Content-Length'); - const total = parseInt(contentLength ?? '0'); - let loaded = 0; - - await fs.promises.mkdir(path.dirname(filePath), { recursive: true }); - const fileStream = fs.createWriteStream(filePath); - const reader = response.body.getReader(); - - while (true) { - const { done, value } = await reader.read(); - if (done) { - break; - } - - await new Promise((resolve, reject) => { - fileStream.write(value, (err) => { - if (err) { - reject(err); - return; - } - resolve(); - }); - }); - - loaded += value.length; - const progress = total ? (loaded / total) * 100 : 0; - - progress_callback?.({ progress, loaded, total }); - } - - fileStream.close(); - } catch (error) { - // Clean up the file if an error occurred during download - try { - await fs.promises.unlink(filePath); - } catch {} - throw error; - } - } - - // TODO add the rest? 
- // addAll(requests: RequestInfo[]): Promise; - // delete(request: RequestInfo | URL, options?: CacheQueryOptions): Promise; - // keys(request?: RequestInfo | URL, options?: CacheQueryOptions): Promise>; - // match(request: RequestInfo | URL, options?: CacheQueryOptions): Promise; - // matchAll(request?: RequestInfo | URL, options?: CacheQueryOptions): Promise>; -} - -/** - * - * @param {FileCache|Cache} cache The cache to search - * @param {string[]} names The names of the item to search for - * @returns {Promise} The item from the cache, or undefined if not found. - */ -async function tryCache(cache, ...names) { - for (let name of names) { - try { - let result = await cache.match(name); - if (result) return result; - } catch (e) { - continue; - } - } - return undefined; -} - /** * Retrieves a file from either a remote URL using the Fetch API or from the local file system using the FileSystem API. * If the filesystem is available and `env.useCache = true`, the file will be downloaded and cached. @@ -419,49 +128,8 @@ export async function getModelFile(path_or_repo_id, filename, fatal = true, opti file: filename, }); - // First, check if the a caching backend is available - // If no caching mechanism available, will download the file every time - let cache; - if (!cache && env.useCustomCache) { - // Allow the user to specify a custom cache system. - if (!env.customCache) { - throw Error('`env.useCustomCache=true`, but `env.customCache` is not defined.'); - } - - // Check that the required methods are defined: - if (!env.customCache.match || !env.customCache.put) { - throw new Error( - '`env.customCache` must be an object which implements the `match` and `put` functions of the Web Cache API. 
' + - 'For more information, see https://developer.mozilla.org/en-US/docs/Web/API/Cache', - ); - } - cache = env.customCache; - } - - if (!cache && env.useBrowserCache) { - if (typeof caches === 'undefined') { - throw Error('Browser cache is not available in this environment.'); - } - try { - // In some cases, the browser cache may be visible, but not accessible due to security restrictions. - // For example, when running an application in an iframe, if a user attempts to load the page in - // incognito mode, the following error is thrown: `DOMException: Failed to execute 'open' on 'CacheStorage': - // An attempt was made to break through the security policy of the user agent.` - // So, instead of crashing, we just ignore the error and continue without using the cache. - cache = await caches.open('transformers-cache'); - } catch (e) { - console.warn('An error occurred while opening the browser cache:', e); - } - } - - if (!cache && env.useFSCache) { - if (!apis.IS_FS_AVAILABLE) { - throw Error('File System Cache is not available in this environment.'); - } - - // If `cache_dir` is not specified, use the default cache directory - cache = new FileCache(options.cache_dir ?? env.cacheDir); - } + /** @type {import('./cache.js').CacheInterface | null} */ + const cache = await getCache(options?.cache_dir); const revision = options.revision ?? 'main'; const requestURL = pathJoin(path_or_repo_id, filename); @@ -491,7 +159,7 @@ export async function getModelFile(path_or_repo_id, filename, fatal = true, opti // Whether to cache the final response in the end. 
let toCacheResponse = false; - /** @type {Response|FileResponse|undefined} */ + /** @type {Response|import('./hub/FileResponse.js').default|undefined|string} */ let response; if (cache) { @@ -503,7 +171,7 @@ export async function getModelFile(path_or_repo_id, filename, fatal = true, opti } const cacheHit = response !== undefined; - if (response === undefined) { + if (!cacheHit) { // Caching not available, or file is not cached, so we perform the request if (env.allowLocalModels) { @@ -528,7 +196,7 @@ export async function getModelFile(path_or_repo_id, filename, fatal = true, opti } } - if (response === undefined || response.status === 404) { + if (response === undefined || (typeof response !== 'string' && response.status === 404)) { // File not found locally. This means either: // - The user has disabled local file access (`env.allowLocalModels=false`) // - the path is a valid HTTP url (`response === undefined`) @@ -585,37 +253,39 @@ export async function getModelFile(path_or_repo_id, filename, fatal = true, opti /** @type {Uint8Array} */ let buffer; - if (!options.progress_callback) { - // If no progress callback is specified, we can use the `.arrayBuffer()` - // method to read the response. - buffer = new Uint8Array(await response.arrayBuffer()); - } else if ( - cacheHit && // The item is being read from the cache - typeof navigator !== 'undefined' && - /firefox/i.test(navigator.userAgent) // We are in Firefox - ) { - // Due to bug in Firefox, we cannot display progress when loading from cache. - // Fortunately, since this should be instantaneous, this should not impact users too much. 
- buffer = new Uint8Array(await response.arrayBuffer()); - - // For completeness, we still fire the final progress callback - dispatchCallback(options.progress_callback, { - status: 'progress', - name: path_or_repo_id, - file: filename, - progress: 100, - loaded: buffer.length, - total: buffer.length, - }); - } else { - buffer = await readResponse(response, (data) => { + if (typeof response !== 'string') { + if (!options.progress_callback) { + // If no progress callback is specified, we can use the `.arrayBuffer()` + // method to read the response. + buffer = new Uint8Array(await response.arrayBuffer()); + } else if ( + cacheHit && // The item is being read from the cache + typeof navigator !== 'undefined' && + /firefox/i.test(navigator.userAgent) // We are in Firefox + ) { + // Due to bug in Firefox, we cannot display progress when loading from cache. + // Fortunately, since this should be instantaneous, this should not impact users too much. + buffer = new Uint8Array(await response.arrayBuffer()); + + // For completeness, we still fire the final progress callback dispatchCallback(options.progress_callback, { status: 'progress', name: path_or_repo_id, file: filename, - ...data, + progress: 100, + loaded: buffer.length, + total: buffer.length, + }); + } else { + buffer = await readResponse(response, (data) => { + dispatchCallback(options.progress_callback, { + status: 'progress', + name: path_or_repo_id, + file: filename, + ...data, + }); }); - }); + } } result = buffer; } @@ -641,7 +311,7 @@ export async function getModelFile(path_or_repo_id, filename, fatal = true, opti }) : undefined; await cache.put(cacheKey, /** @type {Response} */ (response), wrapped_progress); - } else { + } else if (typeof response !== 'string') { // NOTE: We use `new Response(buffer, ...)` instead of `response.clone()` to handle LFS files await cache .put( @@ -726,73 +396,3 @@ export async function getModelJSON(modelPath, fileName, fatal = true, options = return JSON.parse(text); } -/** - 
* Read and track progress when reading a Response object - * - * @param {Response|FileResponse} response The Response object to read - * @param {(data: {progress: number, loaded: number, total: number}) => void} progress_callback The function to call with progress updates - * @returns {Promise} A Promise that resolves with the Uint8Array buffer - */ -async function readResponse(response, progress_callback) { - const contentLength = response.headers.get('Content-Length'); - if (contentLength === null) { - console.warn('Unable to determine content-length from response headers. Will expand buffer when needed.'); - } - let total = parseInt(contentLength ?? '0'); - let buffer = new Uint8Array(total); - let loaded = 0; - - const reader = response.body.getReader(); - async function read() { - const { done, value } = await reader.read(); - if (done) return; - - const newLoaded = loaded + value.length; - if (newLoaded > total) { - total = newLoaded; - - // Adding the new data will overflow buffer. - // In this case, we extend the buffer - const newBuffer = new Uint8Array(total); - - // copy contents - newBuffer.set(buffer); - - buffer = newBuffer; - } - buffer.set(value, loaded); - loaded = newLoaded; - - const progress = (loaded / total) * 100; - - // Call your function here - progress_callback({ progress, loaded, total }); - - return read(); - } - - // Actually read - await read(); - - return buffer; -} - -/** - * Joins multiple parts of a path into a single path, while handling leading and trailing slashes. - * - * @param {...string} parts Multiple parts of a path. - * @returns {string} A string representing the joined path. 
- */ -function pathJoin(...parts) { - // https://stackoverflow.com/a/55142565 - parts = parts.map((part, index) => { - if (index) { - part = part.replace(new RegExp('^/'), ''); - } - if (index !== parts.length - 1) { - part = part.replace(new RegExp('/$'), ''); - } - return part; - }); - return parts.join('/'); -} diff --git a/src/utils/hub/FileCache.js b/src/utils/hub/FileCache.js new file mode 100644 index 000000000..c227c8c8f --- /dev/null +++ b/src/utils/hub/FileCache.js @@ -0,0 +1,92 @@ +import fs from 'node:fs'; +import path from 'node:path'; +import FileResponse from './FileResponse.js'; + +/** + * File system cache implementation that implements the CacheInterface. + * Provides `match` and `put` methods compatible with the Web Cache API. + */ +export default class FileCache { + /** + * Instantiate a `FileCache` object. + * @param {string} path + */ + constructor(path) { + this.path = path; + } + + /** + * Checks whether the given request is in the cache. + * @param {string} request + * @returns {Promise} + */ + async match(request) { + let filePath = path.join(this.path, request); + let file = new FileResponse(filePath); + + if (file.exists) { + return file; + } else { + return undefined; + } + } + + /** + * Adds the given response to the cache. + * @param {string} request + * @param {Response} response + * @param {(data: {progress: number, loaded: number, total: number}) => void} [progress_callback] Optional. + * The function to call with progress updates + * @returns {Promise} + */ + async put(request, response, progress_callback = undefined) { + let filePath = path.join(this.path, request); + + try { + const contentLength = response.headers.get('Content-Length'); + const total = parseInt(contentLength ?? 
'0'); + let loaded = 0; + + await fs.promises.mkdir(path.dirname(filePath), { recursive: true }); + const fileStream = fs.createWriteStream(filePath); + const reader = response.body.getReader(); + + while (true) { + const { done, value } = await reader.read(); + if (done) { + break; + } + + await new Promise((resolve, reject) => { + fileStream.write(value, (err) => { + if (err) { + reject(err); + return; + } + resolve(); + }); + }); + + loaded += value.length; + const progress = total ? (loaded / total) * 100 : 0; + + progress_callback?.({ progress, loaded, total }); + } + + fileStream.close(); + } catch (error) { + // Clean up the file if an error occurred during download + try { + await fs.promises.unlink(filePath); + } catch {} + throw error; + } + } + + // TODO add the rest? + // addAll(requests: RequestInfo[]): Promise; + // delete(request: RequestInfo | URL, options?: CacheQueryOptions): Promise; + // keys(request?: RequestInfo | URL, options?: CacheQueryOptions): Promise>; + // match(request: RequestInfo | URL, options?: CacheQueryOptions): Promise; + // matchAll(request?: RequestInfo | URL, options?: CacheQueryOptions): Promise>; +} diff --git a/src/utils/hub/FileResponse.js b/src/utils/hub/FileResponse.js new file mode 100644 index 000000000..2e421f6bc --- /dev/null +++ b/src/utils/hub/FileResponse.js @@ -0,0 +1,121 @@ +import fs from 'node:fs'; + +/** + * Mapping from file extensions to MIME types. + */ +const CONTENT_TYPE_MAP = { + txt: 'text/plain', + html: 'text/html', + css: 'text/css', + js: 'text/javascript', + json: 'application/json', + png: 'image/png', + jpg: 'image/jpeg', + jpeg: 'image/jpeg', + gif: 'image/gif', +}; + +export default class FileResponse { + /** + * Creates a new `FileResponse` object. 
+ * @param {string} filePath + */ + constructor(filePath) { + this.filePath = filePath; + this.headers = new Headers(); + + this.exists = fs.existsSync(filePath); + if (this.exists) { + this.status = 200; + this.statusText = 'OK'; + + let stats = fs.statSync(filePath); + this.headers.set('content-length', stats.size.toString()); + + this.updateContentType(); + + const stream = fs.createReadStream(filePath); + this.body = new ReadableStream({ + start(controller) { + stream.on('data', (chunk) => controller.enqueue(chunk)); + stream.on('end', () => controller.close()); + stream.on('error', (err) => controller.error(err)); + }, + cancel() { + stream.destroy(); + }, + }); + } else { + this.status = 404; + this.statusText = 'Not Found'; + this.body = null; + } + } + + /** + * Updates the 'content-type' header property of the response based on the extension of + * the file specified by the filePath property of the current object. + * @returns {void} + */ + updateContentType() { + // Set content-type header based on file extension + const extension = this.filePath.toString().split('.').pop().toLowerCase(); + this.headers.set('content-type', CONTENT_TYPE_MAP[extension] ?? 'application/octet-stream'); + } + + /** + * Clone the current FileResponse object. + * @returns {FileResponse} A new FileResponse object with the same properties as the current object. + */ + clone() { + let response = new FileResponse(this.filePath); + response.exists = this.exists; + response.status = this.status; + response.statusText = this.statusText; + response.headers = new Headers(this.headers); + return response; + } + + /** + * Reads the contents of the file specified by the filePath property and returns a Promise that + * resolves with an ArrayBuffer containing the file's contents. + * @returns {Promise} A Promise that resolves with an ArrayBuffer containing the file's contents. + * @throws {Error} If the file cannot be read. 
+ */ + async arrayBuffer() { + const data = await fs.promises.readFile(this.filePath); + return /** @type {ArrayBuffer} */ (data.buffer); + } + + /** + * Reads the contents of the file specified by the filePath property and returns a Promise that + * resolves with a Blob containing the file's contents. + * @returns {Promise} A Promise that resolves with a Blob containing the file's contents. + * @throws {Error} If the file cannot be read. + */ + async blob() { + const data = await fs.promises.readFile(this.filePath); + return new Blob([/** @type {any} */ (data)], { type: this.headers.get('content-type') }); + } + + /** + * Reads the contents of the file specified by the filePath property and returns a Promise that + * resolves with a string containing the file's contents. + * @returns {Promise} A Promise that resolves with a string containing the file's contents. + * @throws {Error} If the file cannot be read. + */ + async text() { + return await fs.promises.readFile(this.filePath, 'utf8'); + } + + /** + * Reads the contents of the file specified by the filePath property and returns a Promise that + * resolves with a parsed JavaScript object containing the file's contents. + * + * @returns {Promise} A Promise that resolves with a parsed JavaScript object containing the file's contents. + * @throws {Error} If the file cannot be read. 
+ */ + async json() { + return JSON.parse(await this.text()); + } +} diff --git a/src/utils/hub/constants.js b/src/utils/hub/constants.js new file mode 100644 index 000000000..b87c7e55d --- /dev/null +++ b/src/utils/hub/constants.js @@ -0,0 +1,18 @@ +export const ERROR_MAPPING = { + // 4xx errors (https://developer.mozilla.org/en-US/docs/Web/HTTP/Status#client_error_responses) + 400: 'Bad request error occurred while trying to load file', + 401: 'Unauthorized access to file', + 403: 'Forbidden access to file', + 404: 'Could not locate file', + 408: 'Request timeout error occurred while trying to load file', + + // 5xx errors (https://developer.mozilla.org/en-US/docs/Web/HTTP/Status#server_error_responses) + 500: 'Internal server error error occurred while trying to load file', + 502: 'Bad gateway error occurred while trying to load file', + 503: 'Service unavailable error occurred while trying to load file', + 504: 'Gateway timeout error occurred while trying to load file', +}; + +export const MAX_EXTERNAL_DATA_CHUNKS = 100; + +export const REPO_ID_REGEX = /^(\b[\w\-.]+\b\/)?\b[\w\-.]{1,96}\b$/; diff --git a/src/utils/hub/utils.js b/src/utils/hub/utils.js new file mode 100644 index 000000000..3f8bc8354 --- /dev/null +++ b/src/utils/hub/utils.js @@ -0,0 +1,128 @@ +import { ERROR_MAPPING, REPO_ID_REGEX } from './constants.js'; + +/** + * Joins multiple parts of a path into a single path, while handling leading and trailing slashes. + * + * @param {...string} parts Multiple parts of a path. + * @returns {string} A string representing the joined path. + */ +export function pathJoin(...parts) { + // https://stackoverflow.com/a/55142565 + parts = parts.map((part, index) => { + if (index) { + part = part.replace(new RegExp('^/'), ''); + } + if (index !== parts.length - 1) { + part = part.replace(new RegExp('/$'), ''); + } + return part; + }); + return parts.join('/'); +} + +/** + * Determines whether the given string is a valid URL. 
+ * @param {string|URL} string The string to test for validity as an URL. + * @param {string[]} [protocols=null] A list of valid protocols. If specified, the protocol must be in this list. + * @param {string[]} [validHosts=null] A list of valid hostnames. If specified, the URL's hostname must be in this list. + * @returns {boolean} True if the string is a valid URL, false otherwise. + */ +export function isValidUrl(string, protocols = null, validHosts = null) { + let url; + try { + url = new URL(string); + } catch (_) { + return false; + } + if (protocols && !protocols.includes(url.protocol)) { + return false; + } + if (validHosts && !validHosts.includes(url.hostname)) { + return false; + } + return true; +} + +/** + * Tests whether a string is a valid Hugging Face model ID or not. + * Adapted from https://github.com/huggingface/huggingface_hub/blob/6378820ebb03f071988a96c7f3268f5bdf8f9449/src/huggingface_hub/utils/_validators.py#L119-L170 + * + * @param {string} string The string to test + * @returns {boolean} True if the string is a valid model ID, false otherwise. + */ +export function isValidHfModelId(string) { + if (!REPO_ID_REGEX.test(string)) return false; + if (string.includes('..') || string.includes('--')) return false; + if (string.endsWith('.git') || string.endsWith('.ipynb')) return false; + return true; +} + +/** + * Helper method to handle fatal errors that occur while trying to load a file from the Hugging Face Hub. + * @param {number} status The HTTP status code of the error. + * @param {string} remoteURL The URL of the file that could not be loaded. + * @param {boolean} fatal Whether to raise an error if the file could not be loaded. + * @returns {null} Returns `null` if `fatal = true`. + * @throws {Error} If `fatal = false`. + */ +export function handleError(status, remoteURL, fatal) { + if (!fatal) { + // File was not loaded correctly, but it is optional. + // TODO in future, cache the response? 
+ return null; + } + + const message = ERROR_MAPPING[status] ?? `Error (${status}) occurred while trying to load file`; + throw Error(`${message}: "${remoteURL}".`); +} + +/** + * Read and track progress when reading a Response object + * + * @param {Response|import('./FileResponse.js').default} response The Response object to read + * @param {(data: {progress: number, loaded: number, total: number}) => void} progress_callback The function to call with progress updates + * @returns {Promise} A Promise that resolves with the Uint8Array buffer + */ +export async function readResponse(response, progress_callback) { + const contentLength = response.headers.get('Content-Length'); + if (contentLength === null) { + console.warn('Unable to determine content-length from response headers. Will expand buffer when needed.'); + } + let total = parseInt(contentLength ?? '0'); + let buffer = new Uint8Array(total); + let loaded = 0; + + const reader = response.body.getReader(); + async function read() { + const { done, value } = await reader.read(); + if (done) return; + + const newLoaded = loaded + value.length; + if (newLoaded > total) { + total = newLoaded; + + // Adding the new data will overflow buffer. + // In this case, we extend the buffer + const newBuffer = new Uint8Array(total); + + // copy contents + newBuffer.set(buffer); + + buffer = newBuffer; + } + buffer.set(value, loaded); + loaded = newLoaded; + + const progress = (loaded / total) * 100; + + // Call your function here + progress_callback({ progress, loaded, total }); + + return read(); + } + + // Actually read + await read(); + + return buffer; +} diff --git a/webpack.config.js b/webpack.config.js index 4d8edb24f..d1e264ac2 100644 --- a/webpack.config.js +++ b/webpack.config.js @@ -171,12 +171,14 @@ const NODE_EXTERNAL_MODULES = [ "node:fs", "node:path", "node:url", + "node:stream", + "node:stream/promises", ]; // Do not bundle node-only packages when bundling for the web. 
// NOTE: We can exclude the "node:" prefix for built-in modules here, // since we apply the `StripNodePrefixPlugin` to strip it. -const WEB_IGNORE_MODULES = ["onnxruntime-node", "sharp", "fs", "path", "url"]; +const WEB_IGNORE_MODULES = ["onnxruntime-node", "sharp", "fs", "path", "url", "stream", "stream/promises"]; // Do not bundle the following modules with webpack (mark as external) const WEB_EXTERNAL_MODULES = ["onnxruntime-common", "onnxruntime-web"];