diff --git a/src/platform/embeddings/common/embeddingsComputer.ts b/src/platform/embeddings/common/embeddingsComputer.ts index 7b1a7a7f3c..e08b4f69b0 100644 --- a/src/platform/embeddings/common/embeddingsComputer.ts +++ b/src/platform/embeddings/common/embeddingsComputer.ts @@ -77,6 +77,10 @@ export interface Embedding { readonly value: EmbeddingVector; } +/** + * Validates if a value is a proper Embedding object with enhanced robustness checks. + * Includes validation for NaN, infinity, and proper numeric values. + */ export function isValidEmbedding(value: unknown): value is Embedding { if (typeof value !== 'object' || value === null) { return false; @@ -91,9 +95,13 @@ export function isValidEmbedding(value: unknown): value is Embedding { return false; } - return true; + // Enhanced validation: check for NaN, infinity, and proper numeric values + return asEmbedding.value.every(val => + typeof val === 'number' && + isFinite(val) && + !isNaN(val) + ); } - export interface Embeddings { readonly type: EmbeddingType; readonly values: readonly Embedding[]; diff --git a/src/platform/embeddings/common/embeddingsGrouper.ts b/src/platform/embeddings/common/embeddingsGrouper.ts index 52f7127d79..f8747654fd 100644 --- a/src/platform/embeddings/common/embeddingsGrouper.ts +++ b/src/platform/embeddings/common/embeddingsGrouper.ts @@ -91,7 +91,7 @@ export class EmbeddingsGrouper { return; } - // Batch add all nodes and cache their normalized embeddings + // Batch add all nodes for (const node of nodes) { this.nodes.push(node); } diff --git a/src/platform/embeddings/common/remoteEmbeddingsComputer.ts b/src/platform/embeddings/common/remoteEmbeddingsComputer.ts index bc3d4f3f94..6f2d0787a6 100644 --- a/src/platform/embeddings/common/remoteEmbeddingsComputer.ts +++ b/src/platform/embeddings/common/remoteEmbeddingsComputer.ts @@ -20,7 +20,7 @@ import { ILogService } from '../../log/common/logService'; import { IFetcherService } from '../../networking/common/fetcherService'; import { IEmbeddingsEndpoint, postRequest } from '../../networking/common/networking'; import { ITelemetryService } from '../../telemetry/common/telemetry'; -import { ComputeEmbeddingsOptions, Embedding, EmbeddingType, EmbeddingTypeInfo, EmbeddingVector, Embeddings, IEmbeddingsComputer, getWellKnownEmbeddingTypeInfo } from './embeddingsComputer'; +import { ComputeEmbeddingsOptions, Embedding, EmbeddingType, EmbeddingTypeInfo, EmbeddingVector, Embeddings, IEmbeddingsComputer, getWellKnownEmbeddingTypeInfo, isValidEmbedding } from './embeddingsComputer'; interface CAPIEmbeddingResults { readonly type: 'success'; @@ -137,10 +137,20 @@ export class RemoteEmbeddingsComputer implements IEmbeddingsComputer { throw new Error(`Mismatched embedding result count. Expected: ${batch.length}. Got: ${jsonResponse.embeddings.length}`); } - embeddingsOut.push(...jsonResponse.embeddings.map(embedding => ({ + // Validate embeddings at service boundary + const potentialEmbeddings = jsonResponse.embeddings.map(embedding => ({ type: resolvedType, value: embedding.embedding, - }))); + })); + + const validatedEmbeddings = potentialEmbeddings.filter(embedding => { + const isValid = isValidEmbedding(embedding); + if (!isValid) { + this._logService.warn(`Invalid embedding received from GitHub API, filtering out invalid embedding`); + } + return isValid; + }); + embeddingsOut.push(...validatedEmbeddings); } return { type: embeddingType, values: embeddingsOut }; @@ -289,7 +299,17 @@ export class RemoteEmbeddingsComputer implements IEmbeddingsComputer { embedding: number[]; }; if (response.status === 200 && jsonResponse.data) { - return { type: 'success', embeddings: jsonResponse.data.map((d: EmbeddingResponse) => d.embedding) }; + // Validate embeddings at service boundary + const validatedEmbeddings = jsonResponse.data + .map((d: EmbeddingResponse) => d.embedding) + .filter((embedding: number[]) => { + if (!Array.isArray(embedding) || embedding.length === 0 || embedding.some(val => typeof val !== 'number' || !isFinite(val))) { + this._logService.warn(`Invalid embedding received from CAPI, filtering out invalid embedding with ${embedding?.length || 0} dimensions`); + return false; + } + return true; + }); + return { type: 'success', embeddings: validatedEmbeddings }; } else { return { type: 'failed', reason: jsonResponse.error }; } diff --git a/src/platform/embeddings/test/node/embeddingsGrouper.spec.ts b/src/platform/embeddings/test/node/embeddingsGrouper.spec.ts index 66b6ad7b73..f96f0ab38b 100644 --- a/src/platform/embeddings/test/node/embeddingsGrouper.spec.ts +++ b/src/platform/embeddings/test/node/embeddingsGrouper.spec.ts @@ -542,6 +542,146 @@ describe('EmbeddingsGrouper', () => { expect(totalNodesInClusters).toBe(nodes.length); }); }); + + describe('defensive handling and robustness', () => { + describe('invalid embeddings handling', () => { + it('should handle null and undefined embeddings', () => { + // This test verifies the fix for runtime crashes from corrupted cache data + const nodes = [ + createNode('validTool', 'cat1', [1, 0.8, 0.6]), + // Simulate corrupted cache entries - these should be filtered out early + { value: { name: 'corruptedTool1', category: 'cat1' }, embedding: null as any }, + { value: { name: 'corruptedTool2', category: 'cat1' }, embedding: undefined as any }, + createNode('validTool2', 'cat1', [0.9, 0.7, 0.5]) + ]; + + expect(() => { + nodes.forEach(node => grouper.addNode(node)); + grouper.recluster(); + }).not.toThrow(); + + const clusters = grouper.getClusters(); + expect(clusters.length).toBeGreaterThan(0); + + // Should only include valid nodes in clusters (corrupted ones filtered out early) + const totalValidNodes = clusters.reduce((sum, cluster) => sum + cluster.nodes.length, 0); + expect(totalValidNodes).toBe(2); // Only the two valid tools + }); + + it('should handle embeddings with null/undefined value arrays', () => { + const nodes = [ + createNode('validTool', 'cat1', [1, 0.8, 0.6]), + { + value: { name: 'nullValuesTool', category: 'cat1' }, + embedding: { type: EmbeddingType.text3small_512, value: null as any } + }, + { + value: { name: 'undefinedValuesTool', category: 'cat1' }, + embedding: { type: EmbeddingType.text3small_512, value: undefined as any } + }, + createNode('validTool2', 'cat1', [0.9, 0.7, 0.5]) + ]; + + expect(() => { + nodes.forEach(node => grouper.addNode(node)); + grouper.recluster(); + }).not.toThrow(); + + const clusters = grouper.getClusters(); + expect(clusters.length).toBeGreaterThan(0); + }); + + it('should handle empty embedding arrays', () => { + const nodes = [ + createNode('validTool', 'cat1', [1, 0.8, 0.6]), + { + value: { name: 'emptyEmbedding', category: 'cat1' }, + embedding: createEmbedding([]) + }, + createNode('validTool2', 'cat1', [0.9, 0.7, 0.5]) + ]; + + expect(() => { + nodes.forEach(node => grouper.addNode(node)); + grouper.recluster(); + }).not.toThrow(); + + const clusters = grouper.getClusters(); + expect(clusters.length).toBeGreaterThan(0); + }); + }); + + describe('NaN values handling', () => { + it('should handle embeddings containing NaN values', () => { + const nodes = [ + createNode('validTool', 'cat1', [1, 0.8, 0.6]), + createNode('nanTool', 'cat1', [NaN, 0.5, NaN]), + createNode('mixedNanTool', 'cat1', [0.7, NaN, 0.4]), + createNode('validTool2', 'cat1', [0.9, 0.7, 0.5]) + ]; + + // The main goal: should not crash when processing NaN values + expect(() => { + nodes.forEach(node => grouper.addNode(node)); + grouper.recluster(); + }).not.toThrow(); + + const clusters = grouper.getClusters(); + expect(clusters.length).toBeGreaterThan(0); + + // Verify that centroids exist (NaN handling may vary) + clusters.forEach(cluster => { + expect(cluster.centroid).toBeDefined(); + expect(Array.isArray(cluster.centroid)).toBe(true); + }); + }); + + it('should handle zero vectors without division by zero', () => { + const nodes = [ + createNode('validTool', 'cat1', [1, 0.8, 0.6]), + createNode('zeroVector', 'cat1', [0, 0, 0]), + createNode('validTool2', 'cat1', [0.9, 0.7, 0.5]) + ]; + + expect(() => { + nodes.forEach(node => grouper.addNode(node)); + grouper.recluster(); + }).not.toThrow(); + + const clusters = grouper.getClusters(); + expect(clusters.length).toBeGreaterThan(0); + }); + }); + + describe('dimension mismatches handling', () => { + it('should handle embeddings with different dimensions', () => { + const nodes = [ + createNode('tool3d', 'cat1', [1, 0.8, 0.6]), + createNode('tool2d', 'cat1', [0.9, 0.7]), // Different dimension + createNode('tool4d', 'cat1', [0.8, 0.6, 0.5, 0.4]), // Different dimension + createNode('tool3d2', 'cat1', [0.7, 0.5, 0.3]) + ]; + + // The main goal: should not crash when processing dimension mismatches + expect(() => { + nodes.forEach(node => grouper.addNode(node)); + grouper.recluster(); + }).not.toThrow(); + + const clusters = grouper.getClusters(); + expect(clusters.length).toBeGreaterThan(0); + + // Verify that centroids exist + clusters.forEach(cluster => { + expect(cluster.centroid).toBeDefined(); + expect(cluster.centroid.length).toBeGreaterThan(0); + + // Just ensure we don't have empty clusters + expect(cluster.nodes.length).toBeGreaterThan(0); + }); + }); + }); + }); });