1616package org .springframework .data .mongodb .repository .query ;
1717
1818import java .util .ArrayList ;
19+ import java .util .Collection ;
1920import java .util .Iterator ;
2021import java .util .List ;
2122import java .util .function .Supplier ;
2627import org .springframework .data .domain .Page ;
2728import org .springframework .data .domain .Pageable ;
2829import org .springframework .data .domain .Range ;
29- import org .springframework .data .domain .Score ;
3030import org .springframework .data .domain .SearchResult ;
3131import org .springframework .data .domain .SearchResults ;
32+ import org .springframework .data .domain .Similarity ;
3233import org .springframework .data .domain .Slice ;
3334import org .springframework .data .domain .SliceImpl ;
34- import org .springframework .data .domain .Sort ;
35- import org .springframework .data .domain .Vector ;
3635import org .springframework .data .geo .Distance ;
3736import org .springframework .data .geo .GeoPage ;
3837import org .springframework .data .geo .GeoResult ;
4645import org .springframework .data .mongodb .core .ExecutableRemoveOperation .TerminatingRemove ;
4746import org .springframework .data .mongodb .core .ExecutableUpdateOperation .ExecutableUpdate ;
4847import org .springframework .data .mongodb .core .MongoOperations ;
49- import org .springframework .data .mongodb .core .aggregation .Aggregation ;
5048import org .springframework .data .mongodb .core .aggregation .AggregationOperation ;
5149import org .springframework .data .mongodb .core .aggregation .AggregationResults ;
5250import org .springframework .data .mongodb .core .aggregation .TypedAggregation ;
53- import org .springframework .data .mongodb .core .aggregation .VectorSearchOperation ;
5451import org .springframework .data .mongodb .core .query .NearQuery ;
5552import org .springframework .data .mongodb .core .query .Query ;
5653import org .springframework .data .mongodb .core .query .UpdateDefinition ;
@@ -225,7 +222,7 @@ private static boolean isListOfGeoResult(TypeInformation<?> returnType) {
225222 }
226223
227224 /**
228- * {@link MongoQueryExecution} to execute vector search
225+ * {@link MongoQueryExecution} to execute vector search.
229226 *
230227 * @author Mark Paluch
231228 * @since 5.0
@@ -235,118 +232,64 @@ class VectorSearchExecution implements MongoQueryExecution {
235232 private final MongoOperations operations ;
236233 private final MongoQueryMethod method ;
237234 private final String collectionName ;
238- private final @ Nullable Integer numCandidates ;
239- private final VectorSearchOperation .SearchType searchType ;
240- private final MongoParameterAccessor accessor ;
241- private final Class <Object > outputType ;
242- private final String path ;
235+ private final VectorSearchDelegate .QueryMetadata queryMetadata ;
236+ private final List <AggregationOperation > pipeline ;
243237
244238 public VectorSearchExecution (MongoOperations operations , MongoQueryMethod method , String collectionName ,
245- String path , @ Nullable Integer numCandidates , VectorSearchOperation .SearchType searchType ,
246- MongoParameterAccessor accessor , Class <Object > outputType ) {
239+ VectorSearchDelegate .QueryMetadata queryMetadata , MongoParameterAccessor accessor ) {
247240
248241 this .operations = operations ;
249242 this .collectionName = collectionName ;
250- this .path = path ;
251- this .numCandidates = numCandidates ;
243+ this .queryMetadata = queryMetadata ;
252244 this .method = method ;
253- this .searchType = searchType ;
254- this .accessor = accessor ;
255- this .outputType = outputType ;
245+ this .pipeline = queryMetadata .getAggregationPipeline (method , accessor );
256246 }
257247
258248 @ Override
259249 public Object execute (Query query ) {
260250
261- SearchResults <?> results = doExecuteQuery ( query );
262- return isListOfSearchResult ( method . getReturnType ()) ? results . getContent () : results ;
263- }
251+ AggregationResults <?> aggregated = operations . aggregate (
252+ TypedAggregation . newAggregation ( queryMetadata . outputType (), pipeline ), collectionName ,
253+ queryMetadata . outputType ());
264254
265- @ SuppressWarnings ("unchecked" )
266- SearchResults <Object > doExecuteQuery (Query query ) {
255+ List <?> mappedResults = aggregated .getMappedResults ();
267256
268- Vector vector = accessor .getVector ();
269- Score score = accessor .getScore ();
270- Range <Score > distance = accessor .getScoreRange ();
271- int limit ;
257+ if (isSearchResult (method .getReturnType ())) {
272258
273- if (query .isLimited ()) {
274- limit = query .getLimit ();
275- } else {
276- limit = Math .max (1 , numCandidates != null ? numCandidates / 20 : 1 );
277- }
259+ List <org .bson .Document > rawResults = aggregated .getRawResults ().getList ("results" , org .bson .Document .class );
260+ List <SearchResult <Object >> result = new ArrayList <>(mappedResults .size ());
278261
279- List <AggregationOperation > stages = new ArrayList <>();
280- VectorSearchOperation $vectorSearch = Aggregation .vectorSearch (method .getAnnotatedHint ()).path (path )
281- .vector (vector ).limit (limit );
262+ for (int i = 0 ; i < mappedResults .size (); i ++) {
263+ Document document = rawResults .get (i );
264+ SearchResult <Object > searchResult = new SearchResult <>(mappedResults .get (i ),
265+ Similarity .raw (document .getDouble ("__score__" ), queryMetadata .scoringFunction ()));
282266
283- if (numCandidates != null ) {
284- $vectorSearch = $vectorSearch .numCandidates (numCandidates );
285- }
267+ result .add (searchResult );
268+ }
286269
287- $vectorSearch = $vectorSearch .filter (query .getQueryObject ());
288- $vectorSearch = $vectorSearch .searchType (searchType );
289- $vectorSearch = $vectorSearch .withSearchScore ("__score__" );
290-
291- if (score != null ) {
292- $vectorSearch = $vectorSearch .withFilterBySore (c -> {
293- c .gt (score .getValue ());
294- });
295- } else if (distance .getLowerBound ().isBounded () || distance .getUpperBound ().isBounded ()) {
296- $vectorSearch = $vectorSearch .withFilterBySore (c -> {
297- Range .Bound <Score > lower = distance .getLowerBound ();
298- if (lower .isBounded ()) {
299- double value = lower .getValue ().get ().getValue ();
300- if (lower .isInclusive ()) {
301- c .gte (value );
302- } else {
303- c .gt (value );
304- }
305- }
306-
307- Range .Bound <Score > upper = distance .getUpperBound ();
308- if (upper .isBounded ()) {
309-
310- double value = upper .getValue ().get ().getValue ();
311- if (upper .isInclusive ()) {
312- c .lte (value );
313- } else {
314- c .lt (value );
315- }
316- }
317- });
270+ return isListOfSearchResult (method .getReturnType ()) ? result : new SearchResults <>(result );
318271 }
319272
320- stages .add ($vectorSearch );
321-
322- if (query .isSorted ()) {
323- // TODO stages.add(Aggregation.sort(query.with()));
324- } else {
325- stages .add (Aggregation .sort (Sort .Direction .DESC , "__score__" ));
326- }
327-
328- AggregationResults <Object > aggregated = operations
329- .aggregate (TypedAggregation .<Object > newAggregation (outputType , stages ), collectionName , outputType );
330-
331- List <Object > mappedResults = aggregated .getMappedResults ();
332- List <org .bson .Document > rawResults = aggregated .getRawResults ().getList ("results" , org .bson .Document .class );
333-
334- List <SearchResult <Object >> result = new ArrayList <>(mappedResults .size ());
273+ return mappedResults ;
274+ }
335275
336- for (int i = 0 ; i < mappedResults .size (); i ++) {
337- Document document = rawResults .get (i );
338- SearchResult <Object > searchResult = new SearchResult <>(mappedResults .get (i ),
339- Score .of (document .getDouble ("__score__" )));
276+ private static boolean isListOfSearchResult (TypeInformation <?> returnType ) {
340277
341- result .add (searchResult );
278+ if (!Collection .class .isAssignableFrom (returnType .getType ())) {
279+ return false ;
342280 }
343281
344- return new SearchResults <>(result );
282+ TypeInformation <?> componentType = returnType .getComponentType ();
283+ return componentType != null && SearchResult .class .equals (componentType .getType ());
345284 }
346285
347- private static boolean isListOfSearchResult (TypeInformation <?> returnType ) {
286+ private static boolean isSearchResult (TypeInformation <?> returnType ) {
348287
349- if (!returnType .getType ().equals (List .class )) {
288+ if (SearchResults .class .isAssignableFrom (returnType .getType ())) {
289+ return true ;
290+ }
291+
292+ if (!Iterable .class .isAssignableFrom (returnType .getType ())) {
350293 return false ;
351294 }
352295
0 commit comments