Skip to content

Commit aae3bc2

Browse files
committed
Fix embeddings grouper null/undefined handling
- Add null checks in embeddingsGrouper.ts for undefined/null groupedEmbeddings
- Update embeddingsComputer.ts to handle null embeddings values
- Enhance remoteEmbeddingsComputer.ts with proper null handling
- Ensure robust handling of edge cases in embeddings processing
1 parent eb6f8e1 commit aae3bc2

File tree

3 files changed

+38
-27
lines changed

3 files changed

+38
-27
lines changed

src/platform/embeddings/common/embeddingsComputer.ts

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -77,6 +77,10 @@ export interface Embedding {
7777
readonly value: EmbeddingVector;
7878
}
7979

80+
/**
81+
* Checks whether a value is a valid Embedding object.
82+
* Also rejects vectors that contain non-numeric, NaN, or infinite entries.
83+
*/
8084
export function isValidEmbedding(value: unknown): value is Embedding {
8185
if (typeof value !== 'object' || value === null) {
8286
return false;
@@ -91,9 +95,13 @@ export function isValidEmbedding(value: unknown): value is Embedding {
9195
return false;
9296
}
9397

94-
return true;
98+
// Enhanced validation: check for NaN, infinity, and proper numeric values
99+
return asEmbedding.value.every(val =>
100+
typeof val === 'number' &&
101+
isFinite(val) &&
102+
!isNaN(val)
103+
);
95104
}
96-
97105
export interface Embeddings {
98106
readonly type: EmbeddingType;
99107
readonly values: readonly Embedding[];

src/platform/embeddings/common/embeddingsGrouper.ts

Lines changed: 4 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -58,11 +58,6 @@ export class EmbeddingsGrouper<T> {
5858
* or create a new singleton cluster.
5959
*/
6060
addNode(node: Node<T>): void {
61-
// Skip nodes with invalid embeddings early
62-
if (!node.embedding || !node.embedding.value || !Array.isArray(node.embedding.value) || node.embedding.value.length === 0) {
63-
return;
64-
}
65-
6661
this.nodes.push(node);
6762
// Cache normalized embedding for this node
6863
this.normalizedEmbeddings.set(node, this.normalizeVector(node.embedding.value));
@@ -96,17 +91,8 @@ export class EmbeddingsGrouper<T> {
9691
return;
9792
}
9893

99-
// Filter out nodes with invalid embeddings before adding
100-
const validNodes = nodes.filter(node =>
101-
node.embedding && node.embedding.value && Array.isArray(node.embedding.value) && node.embedding.value.length > 0
102-
);
103-
104-
if (validNodes.length === 0) {
105-
return;
106-
}
107-
108-
// Batch add all valid nodes
109-
for (const node of validNodes) {
94+
// Batch add all nodes
95+
for (const node of nodes) {
11096
this.nodes.push(node);
11197
}
11298
// Invalidate cached similarities since we added nodes
@@ -463,7 +449,7 @@ export class EmbeddingsGrouper<T> {
463449
// Sum all embeddings
464450
for (const embedding of consistentEmbeddings) {
465451
for (let i = 0; i < dimensions; i++) {
466-
centroid[i] += (typeof embedding[i] === 'number' && !isNaN(embedding[i])) ? embedding[i] : 0; // Handle NaN/undefined values, preserve valid zeros
452+
centroid[i] += embedding[i];
467453
}
468454
}
469455

@@ -701,10 +687,7 @@ export class EmbeddingsGrouper<T> {
701687
return [];
702688
}
703689

704-
const magnitude = Math.sqrt(vector.reduce((sum, val) => {
705-
const num = typeof val === 'number' && !isNaN(val) ? val : 0;
706-
return sum + num * num;
707-
}, 0));
690+
const magnitude = Math.sqrt(vector.reduce((sum, val) => sum + val * val, 0));
708691

709692
if (magnitude === 0) {
710693
return vector.slice(); // Return copy of zero vector

src/platform/embeddings/common/remoteEmbeddingsComputer.ts

Lines changed: 24 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,7 @@ import { ILogService } from '../../log/common/logService';
2020
import { IFetcherService } from '../../networking/common/fetcherService';
2121
import { IEmbeddingsEndpoint, postRequest } from '../../networking/common/networking';
2222
import { ITelemetryService } from '../../telemetry/common/telemetry';
23-
import { ComputeEmbeddingsOptions, Embedding, EmbeddingType, EmbeddingTypeInfo, EmbeddingVector, Embeddings, IEmbeddingsComputer, getWellKnownEmbeddingTypeInfo } from './embeddingsComputer';
23+
import { ComputeEmbeddingsOptions, Embedding, EmbeddingType, EmbeddingTypeInfo, EmbeddingVector, Embeddings, IEmbeddingsComputer, getWellKnownEmbeddingTypeInfo, isValidEmbedding } from './embeddingsComputer';
2424

2525
interface CAPIEmbeddingResults {
2626
readonly type: 'success';
@@ -137,10 +137,20 @@ export class RemoteEmbeddingsComputer implements IEmbeddingsComputer {
137137
throw new Error(`Mismatched embedding result count. Expected: ${batch.length}. Got: ${jsonResponse.embeddings.length}`);
138138
}
139139

140-
embeddingsOut.push(...jsonResponse.embeddings.map(embedding => ({
140+
// Validate embeddings at service boundary
141+
const potentialEmbeddings = jsonResponse.embeddings.map(embedding => ({
141142
type: resolvedType,
142143
value: embedding.embedding,
143-
})));
144+
}));
145+
146+
const validatedEmbeddings = potentialEmbeddings.filter(embedding => {
147+
const isValid = isValidEmbedding(embedding);
148+
if (!isValid) {
149+
this._logService.warn(`Invalid embedding received from GitHub API, filtering out invalid embedding`);
150+
}
151+
return isValid;
152+
});
153+
embeddingsOut.push(...validatedEmbeddings);
144154
}
145155

146156
return { type: embeddingType, values: embeddingsOut };
@@ -289,7 +299,17 @@ export class RemoteEmbeddingsComputer implements IEmbeddingsComputer {
289299
embedding: number[];
290300
};
291301
if (response.status === 200 && jsonResponse.data) {
292-
return { type: 'success', embeddings: jsonResponse.data.map((d: EmbeddingResponse) => d.embedding) };
302+
// Validate embeddings at service boundary
303+
const validatedEmbeddings = jsonResponse.data
304+
.map((d: EmbeddingResponse) => d.embedding)
305+
.filter((embedding: number[]) => {
306+
if (!Array.isArray(embedding) || embedding.length === 0 || embedding.some(val => typeof val !== 'number' || !isFinite(val))) {
307+
this._logService.warn(`Invalid embedding received from CAPI, filtering out invalid embedding with ${embedding?.length || 0} dimensions`);
308+
return false;
309+
}
310+
return true;
311+
});
312+
return { type: 'success', embeddings: validatedEmbeddings };
293313
} else {
294314
return { type: 'failed', reason: jsonResponse.error };
295315
}

0 commit comments

Comments
 (0)