
Commit 075c5c0

Add --queries option to specify number of queries to run
1 parent 8a21508 commit 075c5c0

3 files changed: +33 -5 lines changed


engine/base_client/client.py

Lines changed: 2 additions & 1 deletion
@@ -90,6 +90,7 @@ def run_experiment(
         parallels: [int] = [],
         upload_start_idx: int = 0,
         upload_end_idx: int = -1,
+        num_queries: int = -1,
     ):
         execution_params = self.configurator.execution_params(
             distance=dataset.config.distance, vector_size=dataset.config.vector_size
@@ -161,7 +162,7 @@ def run_experiment(
             )

             search_stats = searcher.search_all(
-                dataset.config.distance, reader.read_queries()
+                dataset.config.distance, reader.read_queries(), num_queries
             )
             # ensure we specify the client count in the results
             search_params["parallel"] = client_count
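
In client.py the change is pure plumbing: run_experiment gains a num_queries parameter (default -1) and forwards it unchanged to searcher.search_all. A small sketch of the intended semantics of the forwarded value, using only names that appear in the diff plus illustrative numbers:

    # Semantics of the forwarded value (illustrative numbers, not from the commit):
    #   num_queries == -1  -> use every query from reader.read_queries() (old behaviour)
    #   num_queries == 500 -> run exactly 500 queries, truncating or cycling as needed
    search_stats = searcher.search_all(
        dataset.config.distance, reader.read_queries(), num_queries
    )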

engine/base_client/search.py

Lines changed: 29 additions & 4 deletions
@@ -2,11 +2,10 @@
 import time
 from multiprocessing import get_context
 from typing import Iterable, List, Optional, Tuple
-import itertools

 import numpy as np
 import tqdm
-import os
+import os

 from dataset_reader.base_reader import Query

@@ -62,6 +61,7 @@ def search_all(
         self,
         distance,
         queries: Iterable[Query],
+        num_queries: int = -1,
     ):
         parallel = self.search_params.get("parallel", 1)
         top = self.search_params.get("top", None)
@@ -72,13 +72,38 @@
         self.setup_search()

         search_one = functools.partial(self.__class__._search_one, top=top)
-        used_queries = queries

+        # Convert queries to a list for potential reuse
+        queries_list = list(queries)

+        # Handle MAX_QUERIES environment variable
         if MAX_QUERIES > 0:
-            used_queries = itertools.islice(queries, MAX_QUERIES)
+            queries_list = queries_list[:MAX_QUERIES]
             print(f"Limiting queries to [0:{MAX_QUERIES-1}]")

+        # Handle num_queries parameter
+        if num_queries > 0:
+            # If we need more queries than available, cycle through the list
+            if num_queries > len(queries_list) and len(queries_list) > 0:
+                print(f"Requested {num_queries} queries but only {len(queries_list)} are available.")
+                print(f"Extending queries by cycling through the available ones.")
+                # Calculate how many complete cycles and remaining items we need
+                complete_cycles = num_queries // len(queries_list)
+                remaining = num_queries % len(queries_list)
+
+                # Create the extended list
+                extended_queries = []
+                for _ in range(complete_cycles):
+                    extended_queries.extend(queries_list)
+                extended_queries.extend(queries_list[:remaining])
+
+                used_queries = extended_queries
+            else:
+                used_queries = queries_list[:num_queries]
+            print(f"Using {num_queries} queries")
+        else:
+            used_queries = queries_list
+
         if parallel == 1:
             start = time.perf_counter()
             precisions, latencies = list(
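
The extension logic above materialises the query list and builds the repeated sequence with an explicit loop. For comparison, a functionally equivalent sketch using itertools (not part of this commit, just an illustration of the same truncate-or-cycle behaviour):

    import itertools

    def take_queries(queries_list, num_queries):
        # Mirrors the logic above: a non-positive num_queries (or an empty list)
        # means "use whatever is available"; otherwise return exactly
        # num_queries items, cycling through the list when it is too short.
        if num_queries <= 0 or not queries_list:
            return list(queries_list)
        return list(itertools.islice(itertools.cycle(queries_list), num_queries))

    # 5 queries requested but only 2 available -> q0, q1, q0, q1, q0
    assert take_queries(["q0", "q1"], 5) == ["q0", "q1", "q0", "q1", "q0"]

Because the new code no longer uses itertools at all, the commit also drops the import itertools line at the top of search.py.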

run.py

Lines changed: 2 additions & 0 deletions
@@ -26,6 +26,7 @@ def run(
     timeout: float = 86400.0,
     upload_start_idx: int = 0,
     upload_end_idx: int = -1,
+    queries: int = typer.Option(-1, help="Number of queries to run. If the available queries are fewer, they will be reused."),
 ):
     """
     Example:
@@ -68,6 +69,7 @@ def run(
             parallels,
             upload_start_idx,
             upload_end_idx,
+            queries,
         )
         client.delete_client()
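
run.py exposes the new parameter through typer, which derives the --queries flag from the parameter name and forwards the value into client.run_experiment. A hypothetical invocation; the --engines and --datasets flags and their values are assumptions about the existing CLI, only --queries comes from this commit:

    python3 run.py --engines "*-default" --datasets "glove-100-angular" --queries 10000

If the chosen dataset ships fewer than 10000 queries, the available ones are cycled until the requested count is reached, as implemented in search_all above.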
