
Commit bb41033

[cleanup] move arg parsing into a separate function (#178)
Pure refactor with no functional change to the code.
1 parent b38b3f5 commit bb41033
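
The change adopts the usual parse_args() / main(args) entry-point split: parse_args() builds the parser and returns the parsed namespace, and main() consumes it. A minimal, self-contained sketch of the resulting structure, assuming a single illustrative --seed flag in place of the many flags the real parser defines:

import argparse
import gc


def parse_args() -> argparse.Namespace:
  # Build the CLI parser and return the parsed arguments.
  parser = argparse.ArgumentParser(
      description="Benchmark the online serving throughput."
  )
  # Illustrative flag only; the real script defines many more arguments.
  parser.add_argument("--seed", type=int, default=0)
  return parser.parse_args()


def main(args: argparse.Namespace):
  # All benchmark logic consumes the parsed namespace, so main() can also
  # be driven programmatically with a constructed argparse.Namespace.
  print(args)


if __name__ == "__main__":
  gc.disable()
  main(parse_args())

One practical effect of the split is that main() no longer depends on sys.argv, so it can be invoked directly, for example from tests, with a namespace built in code.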

File tree

1 file changed (+134 / -134 lines)

benchmarks/benchmark_serving.py

Lines changed: 134 additions & 134 deletions
@@ -706,136 +706,7 @@ def sample_warmup_requests(requests):
        break


-def main(args: argparse.Namespace):
-  print(args)
-  random.seed(args.seed)
-  np.random.seed(args.seed)
-
-  model_id = args.model
-  tokenizer_id = args.tokenizer
-  use_hf_tokenizer = args.use_hf_tokenizer
-
-  prefill_quota = AsyncCounter(init_value=3)
-  active_req_quota = AsyncCounter(init_value=450)
-
-  api_url = f"{args.server}:{args.port}"
-
-  tokenizer = get_tokenizer(model_id, tokenizer_id, use_hf_tokenizer)
-  if tokenizer == "test" or args.dataset == "test":
-    input_requests = mock_requests(
-        args.total_mock_requests
-    )  # e.g. [("AB", 2, "AB", 3)]
-  else:
-    dataset = []
-    if args.dataset == "openorca":
-      dataset = load_openorca_dataset_pkl(args.dataset_path)
-    elif args.dataset == "sharegpt":
-      dataset = load_sharegpt_dataset(
-          args.dataset_path,
-          args.conversation_starter,
-      )
-
-    # A given args.max_output_length value is the max generation step,
-    # when the args.max_output_length is default to None, the sample's golden
-    # output length will be used to decide the generation step.
-    input_requests = sample_requests(
-        dataset=dataset,
-        tokenizer=tokenizer,
-        num_requests=args.num_prompts,
-        max_output_length=args.max_output_length,
-    )
-
-  warmup_requests = None
-  if args.warmup_mode == "full":
-    warmup_requests = input_requests
-  elif args.warmup_mode == "sampled":
-    warmup_requests = list(sample_warmup_requests(input_requests)) * 2
-
-  if warmup_requests:
-    print(f"Warmup (mode: {args.warmup_mode}) is starting.")
-    _, _ = asyncio.run(
-        benchmark(
-            api_url=api_url,
-            tokenizer=tokenizer,
-            input_requests=warmup_requests,
-            request_rate=args.request_rate,
-            disable_tqdm=args.disable_tqdm,
-            prefill_quota=prefill_quota,
-            active_req_quota=active_req_quota,
-            is_warmup=True,
-        )
-    )
-    print(f"Warmup (mode: {args.warmup_mode}) has completed.")
-
-    # TODO: Replace this with warmup complete signal once supported.
-    # Wait for server completely warmup before running the benchmark.
-    time.sleep(5)
-
-  benchmark_result, request_outputs = asyncio.run(
-      benchmark(
-          api_url=api_url,
-          tokenizer=tokenizer,
-          input_requests=input_requests,
-          request_rate=args.request_rate,
-          disable_tqdm=args.disable_tqdm,
-          prefill_quota=prefill_quota,
-          active_req_quota=active_req_quota,
-      )
-  )
-
-  # Process output
-  output = [output.to_dict() for output in request_outputs]
-  if args.run_eval:
-    eval_json = eval_accuracy(output)
-
-  # Save config and results to json
-  if args.save_result:
-    # dimensions values are strings
-    dimensions_json = {}
-    # metrics values are numerical
-    metrics_json = {}
-
-    # Setup
-    current_dt = datetime.now().strftime("%Y%m%d-%H%M%S")
-    dimensions_json["date"] = current_dt
-    dimensions_json["model_id"] = model_id
-    dimensions_json["tokenizer_id"] = tokenizer_id
-    if args.additional_metadata_metrics_to_save is not None:
-      dimensions_json = {
-          **dimensions_json,
-          **json.loads(args.additional_metadata_metrics_to_save),
-      }
-    metrics_json["num_prompts"] = args.num_prompts
-
-    # Traffic
-    metrics_json["request_rate"] = args.request_rate
-    metrics_json = {**metrics_json, **benchmark_result}
-    if args.run_eval:
-      metrics_json = {**metrics_json, **eval_json}
-
-    final_json = {}
-    final_json["metrics"] = metrics_json
-    final_json["dimensions"] = dimensions_json
-
-    # Save to file
-    base_model_id = model_id.split("/")[-1]
-    file_name = (
-        f"JetStream-{args.request_rate}qps-{base_model_id}-{current_dt}.json"
-    )
-    with open(file_name, "w", encoding="utf-8") as outfile:
-      json.dump(final_json, outfile)
-
-  if args.save_request_outputs:
-    file_path = args.request_outputs_file_path
-    with open(file_path, "w", encoding="utf-8") as output_file:
-      json.dump(
-          output,
-          output_file,
-          indent=4,
-      )
-
-
-if __name__ == "__main__":
+def parse_args() -> argparse.Namespace:
   parser = argparse.ArgumentParser(
       description="Benchmark the online serving throughput."
   )
@@ -909,7 +780,6 @@ def main(args: argparse.Namespace):
       default=150,
       help="The maximum number of mock requests to send for benchmark testing.",
   )
-
   parser.add_argument(
       "--max-output-length",
       type=int,
@@ -926,7 +796,6 @@ def main(args: argparse.Namespace):
           "the output length of the golden dataset would be passed."
       ),
   )
-
   parser.add_argument("--seed", type=int, default=0)
   parser.add_argument(
       "--disable-tqdm",
@@ -977,7 +846,138 @@ def main(args: argparse.Namespace):
       choices=["human", "gpt", "both"],
       help="What entity should be the one starting the conversations.",
   )
+  return parser.parse_args()

-  parsed_args = parser.parse_args()
+
+def main(args: argparse.Namespace):
+  print(args)
+  random.seed(args.seed)
+  np.random.seed(args.seed)
+
+  model_id = args.model
+  tokenizer_id = args.tokenizer
+  use_hf_tokenizer = args.use_hf_tokenizer
+
+  prefill_quota = AsyncCounter(init_value=3)
+  active_req_quota = AsyncCounter(init_value=450)
+
+  api_url = f"{args.server}:{args.port}"
+
+  tokenizer = get_tokenizer(model_id, tokenizer_id, use_hf_tokenizer)
+  if tokenizer == "test" or args.dataset == "test":
+    input_requests = mock_requests(
+        args.total_mock_requests
+    )  # e.g. [("AB", 2, "AB", 3)]
+  else:
+    dataset = []
+    if args.dataset == "openorca":
+      dataset = load_openorca_dataset_pkl(args.dataset_path)
+    elif args.dataset == "sharegpt":
+      dataset = load_sharegpt_dataset(
+          args.dataset_path,
+          args.conversation_starter,
+      )
+
+    # A given args.max_output_length value is the max generation step,
+    # when the args.max_output_length is default to None, the sample's golden
+    # output length will be used to decide the generation step.
+    input_requests = sample_requests(
+        dataset=dataset,
+        tokenizer=tokenizer,
+        num_requests=args.num_prompts,
+        max_output_length=args.max_output_length,
+    )
+
+  warmup_requests = None
+  if args.warmup_mode == "full":
+    warmup_requests = input_requests
+  elif args.warmup_mode == "sampled":
+    warmup_requests = list(sample_warmup_requests(input_requests)) * 2
+
+  if warmup_requests:
+    print(f"Warmup (mode: {args.warmup_mode}) is starting.")
+    _, _ = asyncio.run(
+        benchmark(
+            api_url=api_url,
+            tokenizer=tokenizer,
+            input_requests=warmup_requests,
+            request_rate=args.request_rate,
+            disable_tqdm=args.disable_tqdm,
+            prefill_quota=prefill_quota,
+            active_req_quota=active_req_quota,
+            is_warmup=True,
+        )
+    )
+    print(f"Warmup (mode: {args.warmup_mode}) has completed.")
+
+    # TODO: Replace this with warmup complete signal once supported.
+    # Wait for server completely warmup before running the benchmark.
+    time.sleep(5)
+
+  benchmark_result, request_outputs = asyncio.run(
+      benchmark(
+          api_url=api_url,
+          tokenizer=tokenizer,
+          input_requests=input_requests,
+          request_rate=args.request_rate,
+          disable_tqdm=args.disable_tqdm,
+          prefill_quota=prefill_quota,
+          active_req_quota=active_req_quota,
+      )
+  )
+
+  # Process output
+  output = [output.to_dict() for output in request_outputs]
+  if args.run_eval:
+    eval_json = eval_accuracy(output)
+
+  # Save config and results to json
+  if args.save_result:
+    # dimensions values are strings
+    dimensions_json = {}
+    # metrics values are numerical
+    metrics_json = {}
+
+    # Setup
+    current_dt = datetime.now().strftime("%Y%m%d-%H%M%S")
+    dimensions_json["date"] = current_dt
+    dimensions_json["model_id"] = model_id
+    dimensions_json["tokenizer_id"] = tokenizer_id
+    if args.additional_metadata_metrics_to_save is not None:
+      dimensions_json = {
+          **dimensions_json,
+          **json.loads(args.additional_metadata_metrics_to_save),
+      }
+    metrics_json["num_prompts"] = args.num_prompts
+
+    # Traffic
+    metrics_json["request_rate"] = args.request_rate
+    metrics_json = {**metrics_json, **benchmark_result}
+    if args.run_eval:
+      metrics_json = {**metrics_json, **eval_json}
+
+    final_json = {}
+    final_json["metrics"] = metrics_json
+    final_json["dimensions"] = dimensions_json
+
+    # Save to file
+    base_model_id = model_id.split("/")[-1]
+    file_name = (
+        f"JetStream-{args.request_rate}qps-{base_model_id}-{current_dt}.json"
+    )
+    with open(file_name, "w", encoding="utf-8") as outfile:
+      json.dump(final_json, outfile)
+
+  if args.save_request_outputs:
+    file_path = args.request_outputs_file_path
+    with open(file_path, "w", encoding="utf-8") as output_file:
+      json.dump(
+          output,
+          output_file,
+          indent=4,
+      )
+
+
+if __name__ == "__main__":
   gc.disable()
-  main(parsed_args)
+  main(parse_args())
