 import functools
 import time
-from multiprocessing import get_context, Barrier, Process, Queue
+from multiprocessing import get_context
 from typing import Iterable, List, Optional, Tuple
 import itertools
 
@@ -65,10 +65,6 @@ def search_all(
     ):
         parallel = self.search_params.get("parallel", 1)
         top = self.search_params.get("top", None)
-
-        # Convert queries to a list to calculate its length
-        queries = list(queries)  # This allows us to calculate len(queries)
-
         # setup_search may require initialized client
         self.init_client(
             self.host, distance, self.connection_params, self.search_params
@@ -84,56 +80,31 @@ def search_all(
             print(f"Limiting queries to [0:{MAX_QUERIES - 1}]")
 
         if parallel == 1:
-            # Single-threaded execution
             start = time.perf_counter()
-
-            results = [search_one(query) for query in tqdm.tqdm(queries)]
-            total_time = time.perf_counter() - start
-
+            precisions, latencies = list(
+                zip(*[search_one(query) for query in tqdm.tqdm(used_queries)])
+            )
         else:
-            # Dynamically calculate chunk size
-            chunk_size = max(1, len(queries) // parallel)
-            query_chunks = list(chunked_iterable(queries, chunk_size))
+            ctx = get_context(self.get_mp_start_method())
 
-            # Function to be executed by each worker process
-            def worker_function(chunk, result_queue):
-                self.__class__.init_client(
+            with ctx.Pool(
+                processes=parallel,
+                initializer=self.__class__.init_client,
+                initargs=(
                     self.host,
                     distance,
                     self.connection_params,
                     self.search_params,
+                ),
+            ) as pool:
+                if parallel > 10:
+                    time.sleep(15)  # Wait for all processes to start
+                start = time.perf_counter()
+                precisions, latencies = list(
+                    zip(*pool.imap_unordered(search_one, iterable=tqdm.tqdm(used_queries)))
                 )
-                self.setup_search()
-                results = process_chunk(chunk, search_one)
-                result_queue.put(results)
-
-            # Create a queue to collect results
-            result_queue = Queue()
-
-            # Create and start worker processes
-            processes = []
-            for chunk in query_chunks:
-                process = Process(target=worker_function, args=(chunk, result_queue))
-                processes.append(process)
-                process.start()
-
-            # Start measuring time for the critical work
-            start = time.perf_counter()
 
-            # Collect results from all worker processes
-            results = []
-            for _ in processes:
-                results.extend(result_queue.get())
-
-            # Wait for all worker processes to finish
-            for process in processes:
-                process.join()
-
-            # Stop measuring time for the critical work
-            total_time = time.perf_counter() - start
-
-            # Extract precisions and latencies (outside the timed section)
-            precisions, latencies = zip(*results)
+        total_time = time.perf_counter() - start
 
         self.__class__.delete_client()
 
@@ -161,20 +132,3 @@ def post_search(self):
     @classmethod
     def delete_client(cls):
         pass
-
-
-def chunked_iterable(iterable, size):
-    """Yield successive chunks of a given size from an iterable."""
-    it = iter(iterable)
-    while chunk := list(itertools.islice(it, size)):
-        yield chunk
-
-
-def process_chunk(chunk, search_one):
-    """Process a chunk of queries using the search_one function."""
-    return [search_one(query) for query in chunk]
-
-
-def process_chunk_wrapper(chunk, search_one):
-    """Wrapper to process a chunk of queries."""
-    return process_chunk(chunk, search_one)
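
Note: the following is a minimal, standalone sketch of the pattern the added lines rely on: a process pool whose initializer builds one client per worker, with imap_unordered streaming (precision, latency) tuples back as they complete, so total_time is measured once after the if/else. The init_client and search_one bodies here are simplified stand-ins, not the benchmark's real implementations.

# Illustrative sketch only; names and bodies are stand-ins for the real methods.
import time
from multiprocessing import get_context

_client = None  # one client per worker process, created by the pool initializer


def init_client(host, params):
    global _client
    _client = (host, params)  # stands in for an expensive connection setup


def search_one(query):
    # Uses the worker-local client instead of pickling a client with each task.
    time.sleep(0.01)  # stands in for a real search call
    precision, latency = 1.0, 0.01
    return precision, latency


if __name__ == "__main__":
    queries = range(100)
    ctx = get_context("spawn")
    with ctx.Pool(
        processes=4,
        initializer=init_client,
        initargs=("localhost", {"top": 10}),
    ) as pool:
        start = time.perf_counter()
        precisions, latencies = zip(*pool.imap_unordered(search_one, queries))
    total_time = time.perf_counter() - start
    print(len(precisions), round(total_time, 3))

The pool context manager only exits once every query has been consumed, which is why a single timing measurement outside both branches is sufficient.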