From 92a3838e7de026dca74adfce352a5fa6620451cc Mon Sep 17 00:00:00 2001
From: wasamtc
Date: Fri, 5 Dec 2025 08:52:21 +0000
Subject: [PATCH] fix bug of max_sequence_length

---
 src/parallax/server/executor/base_executor.py | 24 ++++++++++++++++++--
 1 file changed, 22 insertions(+), 2 deletions(-)

diff --git a/src/parallax/server/executor/base_executor.py b/src/parallax/server/executor/base_executor.py
index a092ac1e..1191e961 100755
--- a/src/parallax/server/executor/base_executor.py
+++ b/src/parallax/server/executor/base_executor.py
@@ -176,6 +176,10 @@ def __init__(
         )
         if self.shared_state is not None:
             self.shared_state.set_status(ServerState.READY.value)
+
+        # store max_sequence_length
+        self.max_sequence_length = max_sequence_length
+        self.model_path = None
 
     @abstractmethod
     def handle_input_requests(self, requests: List[Request]):
@@ -569,11 +573,27 @@ def _handle_raw_request(self, raw_request: Dict):
         else:
             prompt = convert_chat(raw_request["messages"], raw_request.get("role_mapping"))
             prompt = self.tokenizer.encode(prompt)
+
+        max_req_len = self.max_sequence_length if self.max_sequence_length is not None else 2048
+        input_token_num = len(prompt)
+        if input_token_num >= max_req_len:
+            logger.warning(
+                f"Input token length {input_token_num} exceeds max_sequence_length {max_req_len}. Truncating input."
+            )
+            now_prompt_len = max(5, max_req_len - 10)
+            del prompt[now_prompt_len:]
+            input_token_num = len(prompt)
         max_new_tokens = raw_request.get("max_tokens")
-        if max_new_tokens is None:
-            max_new_tokens = 2048
+        logger.debug(f"max_new_tokens from request: {max_new_tokens}")
+        if max_new_tokens is None or (input_token_num + max_new_tokens) >= max_req_len:
+            logger.warning(
+                f"max_new_tokens {max_new_tokens} is None or input length + max_new_tokens exceeds max_sequence_length {max_req_len}. Adjusting max_new_tokens."
+            )
+            max_new_tokens = max(0, max_req_len - input_token_num)
         max_total_length = len(prompt) + max_new_tokens
+        logger.debug(f"Final max_new_tokens for request ID {rid}: {max_new_tokens}")
+        logger.debug(f"Final input token length for request ID {rid}: {input_token_num}")
         lora_path = raw_request.get("lora_path")