Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions src/platform/embeddings/common/embeddingsComputer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@ export interface Embedding {
readonly value: EmbeddingVector;
}

/**
* Type guard that checks whether a value is a well-formed Embedding.
* Beyond the structural checks, every component of the vector must be a finite number (NaN and ±Infinity are rejected).
*/
export function isValidEmbedding(value: unknown): value is Embedding {
if (typeof value !== 'object' || value === null) {
return false;
Expand All @@ -91,9 +95,13 @@ export function isValidEmbedding(value: unknown): value is Embedding {
return false;
}

return true;
// Every component must be a finite number: rejects non-numbers, NaN, and ±Infinity
return asEmbedding.value.every(val =>
typeof val === 'number' &&
isFinite(val) &&
!isNaN(val)
);
}

export interface Embeddings {
readonly type: EmbeddingType;
readonly values: readonly Embedding[];
Expand Down
2 changes: 1 addition & 1 deletion src/platform/embeddings/common/embeddingsGrouper.ts
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ export class EmbeddingsGrouper<T> {
return;
}

// Batch add all nodes and cache their normalized embeddings
// Batch add all nodes
for (const node of nodes) {
this.nodes.push(node);
}
Expand Down
28 changes: 24 additions & 4 deletions src/platform/embeddings/common/remoteEmbeddingsComputer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import { ILogService } from '../../log/common/logService';
import { IFetcherService } from '../../networking/common/fetcherService';
import { IEmbeddingsEndpoint, postRequest } from '../../networking/common/networking';
import { ITelemetryService } from '../../telemetry/common/telemetry';
import { ComputeEmbeddingsOptions, Embedding, EmbeddingType, EmbeddingTypeInfo, EmbeddingVector, Embeddings, IEmbeddingsComputer, getWellKnownEmbeddingTypeInfo } from './embeddingsComputer';
import { ComputeEmbeddingsOptions, Embedding, EmbeddingType, EmbeddingTypeInfo, EmbeddingVector, Embeddings, IEmbeddingsComputer, getWellKnownEmbeddingTypeInfo, isValidEmbedding } from './embeddingsComputer';

interface CAPIEmbeddingResults {
readonly type: 'success';
Expand Down Expand Up @@ -137,10 +137,20 @@ export class RemoteEmbeddingsComputer implements IEmbeddingsComputer {
throw new Error(`Mismatched embedding result count. Expected: ${batch.length}. Got: ${jsonResponse.embeddings.length}`);
}

embeddingsOut.push(...jsonResponse.embeddings.map(embedding => ({
// Validate embeddings at service boundary
const potentialEmbeddings = jsonResponse.embeddings.map(embedding => ({
type: resolvedType,
value: embedding.embedding,
})));
}));

const validatedEmbeddings = potentialEmbeddings.filter(embedding => {
const isValid = isValidEmbedding(embedding);
if (!isValid) {
this._logService.warn(`Invalid embedding received from GitHub API, filtering out invalid embedding`);
}
return isValid;
});
embeddingsOut.push(...validatedEmbeddings);
}

return { type: embeddingType, values: embeddingsOut };
Expand Down Expand Up @@ -289,7 +299,17 @@ export class RemoteEmbeddingsComputer implements IEmbeddingsComputer {
embedding: number[];
};
if (response.status === 200 && jsonResponse.data) {
return { type: 'success', embeddings: jsonResponse.data.map((d: EmbeddingResponse) => d.embedding) };
// Validate embeddings at service boundary
const validatedEmbeddings = jsonResponse.data
.map((d: EmbeddingResponse) => d.embedding)
.filter((embedding: number[]) => {
if (!Array.isArray(embedding) || embedding.length === 0 || embedding.some(val => typeof val !== 'number' || !isFinite(val))) {
this._logService.warn(`Invalid embedding received from CAPI, filtering out invalid embedding with ${embedding?.length || 0} dimensions`);
return false;
}
return true;
});
return { type: 'success', embeddings: validatedEmbeddings };
} else {
return { type: 'failed', reason: jsonResponse.error };
}
Expand Down
140 changes: 140 additions & 0 deletions src/platform/embeddings/test/node/embeddingsGrouper.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,146 @@ describe('EmbeddingsGrouper', () => {
expect(totalNodesInClusters).toBe(nodes.length);
});
});

describe('defensive handling and robustness', () => {
describe('invalid embeddings handling', () => {
it('should handle null and undefined embeddings', () => {
	// Regression guard: nodes whose embedding is missing entirely (as with corrupted
	// cache data) must not make addNode/recluster throw.
	const nodes = [
		createNode('validTool', 'cat1', [1, 0.8, 0.6]),
		// Stand-ins for corrupted cache entries: no embedding object at all
		{ value: { name: 'corruptedTool1', category: 'cat1' }, embedding: null as any },
		{ value: { name: 'corruptedTool2', category: 'cat1' }, embedding: undefined as any },
		createNode('validTool2', 'cat1', [0.9, 0.7, 0.5])
	];

	const addAllAndRecluster = () => {
		for (const entry of nodes) {
			grouper.addNode(entry);
		}
		grouper.recluster();
	};
	expect(addAllAndRecluster).not.toThrow();

	const clusters = grouper.getClusters();
	expect(clusters.length).toBeGreaterThan(0);

	// Count every node that ended up in some cluster; only the two
	// well-formed tools are expected to be present.
	let clusteredNodeCount = 0;
	for (const group of clusters) {
		clusteredNodeCount += group.nodes.length;
	}
	expect(clusteredNodeCount).toBe(2);
});

it('should handle embeddings with null/undefined value arrays', () => {
	// The embedding objects exist here, but their `value` vector is null or undefined.
	const nodes = [
		createNode('validTool', 'cat1', [1, 0.8, 0.6]),
		{
			value: { name: 'nullValuesTool', category: 'cat1' },
			embedding: { type: EmbeddingType.text3small_512, value: null as any }
		},
		{
			value: { name: 'undefinedValuesTool', category: 'cat1' },
			embedding: { type: EmbeddingType.text3small_512, value: undefined as any }
		},
		createNode('validTool2', 'cat1', [0.9, 0.7, 0.5])
	];

	const addAllAndRecluster = () => {
		for (const entry of nodes) {
			grouper.addNode(entry);
		}
		grouper.recluster();
	};
	expect(addAllAndRecluster).not.toThrow();

	// At least one cluster must still be produced.
	const clusters = grouper.getClusters();
	expect(clusters.length).toBeGreaterThan(0);
});

it('should handle empty embedding arrays', () => {
	// One node carries an embedding built from an empty vector.
	const nodes = [
		createNode('validTool', 'cat1', [1, 0.8, 0.6]),
		{
			value: { name: 'emptyEmbedding', category: 'cat1' },
			embedding: createEmbedding([])
		},
		createNode('validTool2', 'cat1', [0.9, 0.7, 0.5])
	];

	const addAllAndRecluster = () => {
		for (const entry of nodes) {
			grouper.addNode(entry);
		}
		grouper.recluster();
	};
	expect(addAllAndRecluster).not.toThrow();

	// At least one cluster must still be produced.
	const clusters = grouper.getClusters();
	expect(clusters.length).toBeGreaterThan(0);
});
});

describe('NaN values handling', () => {
it('should handle embeddings containing NaN values', () => {
	// Two of the four vectors contain NaN components.
	const nodes = [
		createNode('validTool', 'cat1', [1, 0.8, 0.6]),
		createNode('nanTool', 'cat1', [NaN, 0.5, NaN]),
		createNode('mixedNanTool', 'cat1', [0.7, NaN, 0.4]),
		createNode('validTool2', 'cat1', [0.9, 0.7, 0.5])
	];

	// Primary expectation: processing NaN values must not throw.
	const addAllAndRecluster = () => {
		for (const entry of nodes) {
			grouper.addNode(entry);
		}
		grouper.recluster();
	};
	expect(addAllAndRecluster).not.toThrow();

	const clusters = grouper.getClusters();
	expect(clusters.length).toBeGreaterThan(0);

	// Only the presence and array-ness of each centroid is checked;
	// the centroid's numeric contents are deliberately not asserted.
	for (const group of clusters) {
		expect(group.centroid).toBeDefined();
		expect(Array.isArray(group.centroid)).toBe(true);
	}
});

it('should handle zero vectors without division by zero', () => {
	// The middle node's vector is all zeros, so its magnitude is 0.
	const nodes = [
		createNode('validTool', 'cat1', [1, 0.8, 0.6]),
		createNode('zeroVector', 'cat1', [0, 0, 0]),
		createNode('validTool2', 'cat1', [0.9, 0.7, 0.5])
	];

	const addAllAndRecluster = () => {
		for (const entry of nodes) {
			grouper.addNode(entry);
		}
		grouper.recluster();
	};
	expect(addAllAndRecluster).not.toThrow();

	// At least one cluster must still be produced.
	const clusters = grouper.getClusters();
	expect(clusters.length).toBeGreaterThan(0);
});
});

describe('dimension mismatches handling', () => {
it('should handle embeddings with different dimensions', () => {
	// Mix of 3-, 2- and 4-component vectors.
	const nodes = [
		createNode('tool3d', 'cat1', [1, 0.8, 0.6]),
		createNode('tool2d', 'cat1', [0.9, 0.7]), // 2 components
		createNode('tool4d', 'cat1', [0.8, 0.6, 0.5, 0.4]), // 4 components
		createNode('tool3d2', 'cat1', [0.7, 0.5, 0.3])
	];

	// Primary expectation: mismatched dimensions must not throw.
	const addAllAndRecluster = () => {
		for (const entry of nodes) {
			grouper.addNode(entry);
		}
		grouper.recluster();
	};
	expect(addAllAndRecluster).not.toThrow();

	const clusters = grouper.getClusters();
	expect(clusters.length).toBeGreaterThan(0);

	for (const group of clusters) {
		// Each cluster has a non-empty centroid...
		expect(group.centroid).toBeDefined();
		expect(group.centroid.length).toBeGreaterThan(0);

		// ...and contains at least one node.
		expect(group.nodes.length).toBeGreaterThan(0);
	}
});
});
});
});