Skip to content

Commit 7a72bdd

Browse files
committed
Fix critical bugs in embeddings calculations
- Fix `|| 0` fallback that overwrote legitimate zero values; use explicit type/NaN checks instead
- Fix dimension mismatch causing incorrect centroid division
- Add comprehensive NaN handling in normalizeVector magnitude calculation
- Ensure all embeddings have consistent dimensions before processing

Addresses Copilot feedback on mathematical correctness and edge cases.
1 parent 66f8ee6 commit 7a72bdd

File tree

1 file changed

+15
-9
lines changed

1 file changed

+15
-9
lines changed

src/platform/embeddings/common/embeddingsGrouper.ts

Lines changed: 15 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -424,32 +424,38 @@ export class EmbeddingsGrouper<T> {
424424
return [];
425425
}
426426

427-
// Filter out invalid embeddings
427+
// Filter out invalid embeddings and ensure consistent dimensions
428428
const validEmbeddings = embeddings.filter(embedding => embedding && Array.isArray(embedding) && embedding.length > 0);
429429

430430
if (validEmbeddings.length === 0) {
431431
return [];
432432
}
433433

434-
if (validEmbeddings.length === 1) {
435-
return [...validEmbeddings[0]]; // Copy to avoid mutations
434+
// Ensure all embeddings have the same dimensions as the first valid one
435+
const expectedDimensions = validEmbeddings[0].length;
436+
const consistentEmbeddings = validEmbeddings.filter(embedding => embedding.length === expectedDimensions);
437+
438+
if (consistentEmbeddings.length === 0) {
439+
return [];
440+
}
441+
442+
if (consistentEmbeddings.length === 1) {
443+
return [...consistentEmbeddings[0]]; // Copy to avoid mutations
436444
}
437445

438-
const dimensions = validEmbeddings[0].length;
446+
const dimensions = expectedDimensions;
439447
const centroid = new Array(dimensions).fill(0);
440448

441449
// Sum all embeddings
442-
for (const embedding of validEmbeddings) {
443-
// Additional safety check in case embedding has different dimensions
444-
const embeddingLength = Math.min(embedding.length, dimensions);
445-
for (let i = 0; i < embeddingLength; i++) {
450+
for (const embedding of consistentEmbeddings) {
451+
for (let i = 0; i < dimensions; i++) {
446452
centroid[i] += (typeof embedding[i] === 'number' && !isNaN(embedding[i])) ? embedding[i] : 0; // Handle NaN/undefined values, preserve valid zeros
447453
}
448454
}
449455

450456
// Divide by count to get mean
451457
for (let i = 0; i < dimensions; i++) {
452-
centroid[i] /= validEmbeddings.length;
458+
centroid[i] /= consistentEmbeddings.length;
453459
}
454460

455461
// L2 normalize the centroid

0 commit comments

Comments
 (0)