Skip to content

Commit 2051553

Browse files
committed
Feat: Implement SurrealDB HNSW vector search (Phase 2)
Adds real HNSW vector search with metadata filtering: **SurrealDB Storage Layer** (crates/codegraph-graph/src/surrealdb_storage.rs): - vector_search_knn(): HNSW search using <|K,EF|> operator - vector_search_with_metadata(): Filtered HNSW search (type/lang/path) - get_nodes_by_ids(): Batch node loading with caching **Cloud Search Implementation** (crates/codegraph-mcp/src/server.rs): - Creates dedicated SurrealDB connection for cloud mode - Uses vector_search_with_metadata() for filtered HNSW search - Loads full nodes efficiently with batch query - Returns results with HNSW similarity scores - Detailed performance metrics for each step **Features**: - ✅ Automatic SurrealDB connection from env vars - ✅ HNSW search with configurable ef_search parameter - ✅ Metadata filtering (node_type, language, file_path patterns) - ✅ OR-pattern support for file paths (src/|lib/) - ✅ Cache-aware node loading - ✅ Performance tracking (embedding, connect, search, load, format) **Status**: Phase 2 complete, compiles successfully **Next**: Phase 3 - Jina reranking integration (optional enhancement) **Testing**: Requires SurrealDB instance with HNSW indexes
1 parent c0f4c98 commit 2051553

File tree

2 files changed

+266
-51
lines changed

2 files changed

+266
-51
lines changed

crates/codegraph-graph/src/surrealdb_storage.rs

Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,192 @@ impl SurrealDbStorage {
288288
}
289289
*/
290290

291+
/// Vector search using SurrealDB HNSW indexes
292+
/// Returns node IDs and similarity scores
293+
pub async fn vector_search_knn(
294+
&self,
295+
query_embedding: Vec<f32>,
296+
limit: usize,
297+
ef_search: usize,
298+
) -> Result<Vec<(String, f32)>> {
299+
info!(
300+
"Executing HNSW vector search with limit={}, ef_search={}",
301+
limit, ef_search
302+
);
303+
304+
// Convert f32 to f64 for SurrealDB
305+
let query_vec: Vec<f64> = query_embedding.iter().map(|&f| f as f64).collect();
306+
307+
// SurrealDB HNSW search using <|K,EF|> operator
308+
// vector::distance::knn() reuses pre-computed distance from HNSW
309+
let query = r#"
310+
SELECT id, vector::distance::knn() AS score
311+
FROM nodes
312+
WHERE embedding <|$limit,$ef_search|> $query_embedding
313+
ORDER BY score ASC
314+
LIMIT $limit
315+
"#;
316+
317+
let mut result = self
318+
.db
319+
.query(query)
320+
.bind(("query_embedding", query_vec))
321+
.bind(("limit", limit))
322+
.bind(("ef_search", ef_search))
323+
.await
324+
.map_err(|e| CodeGraphError::Database(format!("HNSW search failed: {}", e)))?;
325+
326+
#[derive(Deserialize)]
327+
struct SearchResult {
328+
id: String,
329+
score: f64,
330+
}
331+
332+
let results: Vec<SearchResult> = result.take(0).map_err(|e| {
333+
CodeGraphError::Database(format!("Failed to extract search results: {}", e))
334+
})?;
335+
336+
Ok(results
337+
.into_iter()
338+
.map(|r| (r.id, r.score as f32))
339+
.collect())
340+
}
341+
342+
/// Vector search with metadata filtering
343+
pub async fn vector_search_with_metadata(
344+
&self,
345+
query_embedding: Vec<f32>,
346+
limit: usize,
347+
ef_search: usize,
348+
node_type: Option<String>,
349+
language: Option<String>,
350+
file_path_pattern: Option<String>,
351+
) -> Result<Vec<(String, f32)>> {
352+
info!(
353+
"Executing filtered HNSW search: type={:?}, lang={:?}, path={:?}",
354+
node_type, language, file_path_pattern
355+
);
356+
357+
let query_vec: Vec<f64> = query_embedding.iter().map(|&f| f as f64).collect();
358+
359+
// Build dynamic WHERE clause
360+
let mut where_clauses =
361+
vec!["embedding <|$limit,$ef_search|> $query_embedding".to_string()];
362+
363+
if let Some(ref nt) = node_type {
364+
where_clauses.push(format!("node_type = '{}'", nt));
365+
}
366+
367+
if let Some(ref lang) = language {
368+
where_clauses.push(format!("language = '{}'", lang));
369+
}
370+
371+
if let Some(ref path) = file_path_pattern {
372+
// Support OR patterns like "src/|lib/"
373+
if path.contains('|') {
374+
let patterns: Vec<String> = path
375+
.split('|')
376+
.map(|p| format!("file_path CONTAINS '{}'", p))
377+
.collect();
378+
where_clauses.push(format!("({})", patterns.join(" OR ")));
379+
} else {
380+
where_clauses.push(format!("file_path CONTAINS '{}'", path));
381+
}
382+
}
383+
384+
let where_clause = where_clauses.join(" AND ");
385+
386+
let query = format!(
387+
r#"
388+
SELECT id, vector::distance::knn() AS score
389+
FROM nodes
390+
WHERE {}
391+
ORDER BY score ASC
392+
LIMIT $limit
393+
"#,
394+
where_clause
395+
);
396+
397+
let mut result = self
398+
.db
399+
.query(&query)
400+
.bind(("query_embedding", query_vec))
401+
.bind(("limit", limit))
402+
.bind(("ef_search", ef_search))
403+
.await
404+
.map_err(|e| CodeGraphError::Database(format!("Filtered HNSW search failed: {}", e)))?;
405+
406+
#[derive(Deserialize)]
407+
struct SearchResult {
408+
id: String,
409+
score: f64,
410+
}
411+
412+
let results: Vec<SearchResult> = result.take(0).map_err(|e| {
413+
CodeGraphError::Database(format!("Failed to extract filtered results: {}", e))
414+
})?;
415+
416+
Ok(results
417+
.into_iter()
418+
.map(|r| (r.id, r.score as f32))
419+
.collect())
420+
}
421+
422+
/// Get multiple nodes by their IDs in one query
423+
pub async fn get_nodes_by_ids(&self, ids: &[String]) -> Result<Vec<CodeNode>> {
424+
if ids.is_empty() {
425+
return Ok(Vec::new());
426+
}
427+
428+
debug!("Getting {} nodes by IDs", ids.len());
429+
430+
// Check cache first for all IDs
431+
let mut nodes = Vec::new();
432+
let mut missing_ids = Vec::new();
433+
434+
if self.config.cache_enabled {
435+
for id_str in ids {
436+
if let Ok(id) = NodeId::parse_str(id_str) {
437+
if let Some(cached) = self.node_cache.get(&id) {
438+
nodes.push(cached.clone());
439+
} else {
440+
missing_ids.push(id_str.clone());
441+
}
442+
}
443+
}
444+
} else {
445+
missing_ids = ids.to_vec();
446+
}
447+
448+
// Fetch missing nodes from database
449+
if !missing_ids.is_empty() {
450+
let query = "SELECT * FROM nodes WHERE id IN $ids";
451+
let mut result = self
452+
.db
453+
.query(query)
454+
.bind(("ids", missing_ids))
455+
.await
456+
.map_err(|e| CodeGraphError::Database(format!("Failed to query nodes: {}", e)))?;
457+
458+
let db_nodes: Vec<HashMap<String, JsonValue>> = result.take(0).map_err(|e| {
459+
CodeGraphError::Database(format!("Failed to extract query results: {}", e))
460+
})?;
461+
462+
for data in db_nodes {
463+
let node = self.surreal_to_node(data)?;
464+
465+
// Update cache
466+
if self.config.cache_enabled {
467+
self.node_cache.insert(node.id, node.clone());
468+
}
469+
470+
nodes.push(node);
471+
}
472+
}
473+
474+
Ok(nodes)
475+
}
476+
291477
/// Convert CodeNode to SurrealDB-compatible format
292478
fn node_to_surreal(&self, node: &CodeNode) -> Result<HashMap<String, JsonValue>> {
293479
let mut data = HashMap::new();

crates/codegraph-mcp/src/server.rs

Lines changed: 80 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -736,7 +736,7 @@ async fn cloud_search_impl(
736736
paths: Option<Vec<String>>,
737737
langs: Option<Vec<String>>,
738738
limit: usize,
739-
graph: &codegraph_graph::CodeGraph,
739+
_graph: &codegraph_graph::CodeGraph,
740740
) -> anyhow::Result<Value> {
741741
use codegraph_core::Language;
742742

@@ -745,94 +745,123 @@ async fn cloud_search_impl(
745745
tracing::info!("🌐 Cloud Mode: SurrealDB HNSW + Jina reranking");
746746

747747
// 1. Generate query embedding using Jina/Cloud provider
748+
let start_embedding = Instant::now();
748749
let embedding_gen = get_embedding_generator().await?;
749750
let query_embedding = embedding_gen.generate_text_embedding(&query).await?;
751+
let embedding_time = start_embedding.elapsed().as_millis() as u64;
750752

751753
// 2. Overretrieve for reranking (3x limit)
752754
let overretrieve_limit = limit * 3;
753755

754-
// 3. SurrealDB HNSW search
755-
let storage = graph.get_storage();
756+
// 3. Create SurrealDB storage connection for cloud mode
757+
let start_connect = Instant::now();
758+
let surrealdb_config = codegraph_graph::SurrealDbConfig {
759+
connection: std::env::var("SURREALDB_URL")
760+
.unwrap_or_else(|_| "ws://localhost:3004".to_string()),
761+
namespace: std::env::var("SURREALDB_NAMESPACE").unwrap_or_else(|_| "codegraph".to_string()),
762+
database: std::env::var("SURREALDB_DATABASE").unwrap_or_else(|_| "main".to_string()),
763+
username: std::env::var("SURREALDB_USERNAME").ok(),
764+
password: std::env::var("SURREALDB_PASSWORD").ok(),
765+
strict_mode: false,
766+
auto_migrate: false, // Don't migrate on every search
767+
cache_enabled: true,
768+
};
756769

757-
// For now, use get_all_nodes and do in-memory similarity search
758-
// TODO: Implement proper HNSW search in SurrealDB storage layer
759-
let all_nodes = storage.get_all_nodes().await?;
770+
let surrealdb_storage = codegraph_graph::SurrealDbStorage::new(surrealdb_config).await?;
771+
let connect_time = start_connect.elapsed().as_millis() as u64;
760772

761-
if all_nodes.is_empty() {
762-
return Ok(json!({
763-
"results": [],
764-
"message": "No nodes found in database. Please index your codebase first.",
765-
"performance": {
766-
"total_ms": start_total.elapsed().as_millis(),
767-
"mode": "cloud"
768-
}
769-
}));
770-
}
773+
// 4. SurrealDB HNSW search with metadata filtering
774+
let start_search = Instant::now();
771775

772-
// Filter by language and path
773-
let lang_filter: Option<Vec<Language>> = langs.as_ref().map(|langs| {
776+
// Build filter parameters
777+
let node_type_filter = langs.as_ref().map(|langs| {
774778
langs
775779
.iter()
776780
.filter_map(|l| match l.to_lowercase().as_str() {
777-
"rust" => Some(Language::Rust),
778-
"python" => Some(Language::Python),
779-
"javascript" | "js" => Some(Language::JavaScript),
780-
"typescript" | "ts" => Some(Language::TypeScript),
781-
"go" => Some(Language::Go),
782-
"java" => Some(Language::Java),
781+
"rust" => Some("Rust".to_string()),
782+
"python" => Some("Python".to_string()),
783+
"javascript" | "js" => Some("JavaScript".to_string()),
784+
"typescript" | "ts" => Some("TypeScript".to_string()),
785+
"go" => Some("Go".to_string()),
786+
"java" => Some("Java".to_string()),
783787
_ => None,
784788
})
785-
.collect()
789+
.next() // Take first matching language for now
786790
});
787791

788-
let filtered_nodes: Vec<_> = all_nodes
789-
.into_iter()
790-
.filter(|node| {
791-
// Filter by language
792-
if let Some(ref langs) = lang_filter {
793-
if !langs.contains(&node.language) {
794-
return false;
795-
}
796-
}
792+
let file_path_pattern = paths.as_ref().map(|p| p.join("|"));
793+
794+
let search_results = surrealdb_storage
795+
.vector_search_with_metadata(
796+
query_embedding.clone(),
797+
overretrieve_limit,
798+
100, // ef_search parameter for HNSW
799+
node_type_filter,
800+
None, // language filter (separate from node_type)
801+
file_path_pattern,
802+
)
803+
.await?;
797804

798-
// Filter by paths
799-
if let Some(ref paths) = paths {
800-
if !paths.iter().any(|p| node.location.file_path.contains(p)) {
801-
return false;
802-
}
805+
let search_time = start_search.elapsed().as_millis() as u64;
806+
807+
if search_results.is_empty() {
808+
return Ok(json!({
809+
"results": [],
810+
"message": "No results found. Ensure codebase is indexed with embeddings.",
811+
"performance": {
812+
"total_ms": start_total.elapsed().as_millis(),
813+
"embedding_ms": embedding_time,
814+
"connect_ms": connect_time,
815+
"search_ms": search_time,
816+
"mode": "cloud"
803817
}
818+
}));
819+
}
804820

805-
true
806-
})
807-
.collect();
821+
// 5. Load full nodes from SurrealDB
822+
let start_load = Instant::now();
823+
let node_ids: Vec<String> = search_results.iter().map(|(id, _)| id.clone()).collect();
824+
let nodes = surrealdb_storage.get_nodes_by_ids(&node_ids).await?;
825+
let load_time = start_load.elapsed().as_millis() as u64;
826+
827+
// 6. TODO: Add Jina reranking here
828+
// For now, use HNSW scores directly
808829

809-
// TODO: For MVP, return basic results without reranking
810-
// Full implementation requires Jina reranking integration
811-
let results: Vec<Value> = filtered_nodes
830+
// 7. Format results
831+
let start_format = Instant::now();
832+
let results: Vec<Value> = nodes
812833
.iter()
834+
.zip(search_results.iter())
813835
.take(limit)
814-
.map(|node| {
836+
.map(|(node, (_id, score))| {
815837
json!({
816838
"id": node.id,
817839
"name": node.name,
818-
"node_type": format!("{:?}", node.node_type),
819-
"language": format!("{:?}", node.language),
840+
"node_type": node.node_type.as_ref().map(|nt| format!("{:?}", nt)).unwrap_or_default(),
841+
"language": node.language.as_ref().map(|l| format!("{:?}", l)).unwrap_or_default(),
820842
"file_path": node.location.file_path,
821-
"start_line": node.location.start_line,
843+
"start_line": node.location.line,
822844
"end_line": node.location.end_line,
823-
"score": 0.5, // Placeholder score
845+
"score": 1.0 - score, // Convert distance to similarity
824846
"summary": node.content.as_deref().unwrap_or("").chars().take(160).collect::<String>()
825847
})
826848
})
827849
.collect();
850+
let format_time = start_format.elapsed().as_millis() as u64;
828851

829852
Ok(json!({
830853
"results": results,
831854
"total_results": results.len(),
832855
"performance": {
833856
"total_ms": start_total.elapsed().as_millis(),
857+
"embedding_generation_ms": embedding_time,
858+
"surrealdb_connection_ms": connect_time,
859+
"hnsw_search_ms": search_time,
860+
"node_loading_ms": load_time,
861+
"formatting_ms": format_time,
834862
"mode": "cloud",
835-
"note": "SurrealDB HNSW search - reranking not yet implemented"
863+
"hnsw_enabled": true,
864+
"reranking_enabled": false // TODO: Phase 2.5
836865
}
837866
}))
838867
}

0 commit comments

Comments
 (0)