22import time
33from multiprocessing import get_context
44from typing import Iterable , List , Optional , Tuple
5- import itertools
65
76import numpy as np
87import tqdm
9- import os
8+ import os
109
1110from dataset_reader .base_reader import Query
1211
@@ -62,6 +61,7 @@ def search_all(
6261 self ,
6362 distance ,
6463 queries : Iterable [Query ],
64+ num_queries : int = - 1 ,
6565 ):
6666 parallel = self .search_params .get ("parallel" , 1 )
6767 top = self .search_params .get ("top" , None )
@@ -72,13 +72,38 @@ def search_all(
7272 self .setup_search ()
7373
7474 search_one = functools .partial (self .__class__ ._search_one , top = top )
75- used_queries = queries
7675
76+ # Convert queries to a list for potential reuse
77+ queries_list = list (queries )
7778
79+ # Handle MAX_QUERIES environment variable
7880 if MAX_QUERIES > 0 :
79- used_queries = itertools . islice ( queries , MAX_QUERIES )
81+ queries_list = queries_list [: MAX_QUERIES ]
8082 print (f"Limiting queries to [0:{ MAX_QUERIES - 1 } ]" )
8183
84+ # Handle num_queries parameter
85+ if num_queries > 0 :
86+ # If we need more queries than available, cycle through the list
87+ if num_queries > len (queries_list ) and len (queries_list ) > 0 :
88+ print (f"Requested { num_queries } queries but only { len (queries_list )} are available." )
89+ print (f"Extending queries by cycling through the available ones." )
90+ # Calculate how many complete cycles and remaining items we need
91+ complete_cycles = num_queries // len (queries_list )
92+ remaining = num_queries % len (queries_list )
93+
94+ # Create the extended list
95+ extended_queries = []
96+ for _ in range (complete_cycles ):
97+ extended_queries .extend (queries_list )
98+ extended_queries .extend (queries_list [:remaining ])
99+
100+ used_queries = extended_queries
101+ else :
102+ used_queries = queries_list [:num_queries ]
103+ print (f"Using { num_queries } queries" )
104+ else :
105+ used_queries = queries_list
106+
82107 if parallel == 1 :
83108 start = time .perf_counter ()
84109 precisions , latencies = list (
0 commit comments