|
2 | 2 | import time |
3 | 3 | from multiprocessing import get_context |
4 | 4 | from typing import Iterable, List, Optional, Tuple |
| 5 | +from itertools import islice |
5 | 6 |
|
6 | 7 | import numpy as np |
7 | 8 | import tqdm |
@@ -112,22 +113,31 @@ def search_all( |
112 | 113 | else: |
113 | 114 | ctx = get_context(self.get_mp_start_method()) |
114 | 115 |
|
115 | | - with ctx.Pool( |
116 | | - processes=parallel, |
117 | | - initializer=self.__class__.init_client, |
118 | | - initargs=( |
| 116 | + def process_initializer(): |
| 117 | + """Initialize each process before starting the search.""" |
| 118 | + self.__class__.init_client( |
119 | 119 | self.host, |
120 | 120 | distance, |
121 | 121 | self.connection_params, |
122 | 122 | self.search_params, |
123 | | - ), |
| 123 | + ) |
| 124 | + self.setup_search() |
| 125 | + |
| 126 | + # Dynamically chunk the generator |
| 127 | + query_chunks = list(chunked_iterable(used_queries, max(1, len(used_queries) // parallel))) |
| 128 | + |
| 129 | + with ctx.Pool( |
| 130 | + processes=parallel, |
| 131 | + initializer=process_initializer, |
124 | 132 | ) as pool: |
125 | 133 | if parallel > 10: |
126 | 134 | time.sleep(15) # Wait for all processes to start |
127 | 135 | start = time.perf_counter() |
128 | | - precisions, latencies = list( |
129 | | - zip(*pool.imap_unordered(search_one, iterable=tqdm.tqdm(used_queries))) |
| 136 | + results = pool.starmap( |
| 137 | + process_chunk, |
| 138 | + [(chunk, search_one) for chunk in query_chunks], |
130 | 139 | ) |
| 140 | + precisions, latencies = zip(*[result for chunk in results for result in chunk]) |
131 | 141 |
|
132 | 142 | total_time = time.perf_counter() - start |
133 | 143 |
|
    @classmethod
    def delete_client(cls):
        # Default no-op. NOTE(review): presumably a lifecycle hook that
        # subclasses override to release a per-process client created by
        # init_client — confirm against the subclasses/callers.
        pass
| 170 | + |
| 171 | + |
def chunked_iterable(iterable, size):
    """Yield successive chunks of at most *size* items from *iterable*.

    Works on any iterable, including one-shot generators; the final chunk
    may be shorter than *size* when the input is exhausted.

    Args:
        iterable: Source of items to split into chunks.
        size: Maximum number of items per chunk; must be >= 1.

    Yields:
        Lists of consecutive items, each of length <= size.

    Raises:
        ValueError: If *size* is less than 1. A size of 0 would otherwise
            silently yield nothing, hiding a caller bug; a negative size
            would surface only as an opaque islice error.
    """
    if size < 1:
        raise ValueError(f"chunk size must be >= 1, got {size}")
    it = iter(iterable)
    # islice consumes up to `size` items per pass; an empty chunk marks exhaustion.
    while chunk := list(islice(it, size)):
        yield chunk
| 177 | + |
| 178 | + |
def process_chunk(chunk, search_one):
    """Run *search_one* over every query in *chunk*.

    Args:
        chunk: Iterable of queries to process.
        search_one: Callable invoked once per query.

    Returns:
        List of per-query results, in the same order as *chunk*.
    """
    return list(map(search_one, chunk))
0 commit comments