diff --git a/src/parallax/server/executor/base_executor.py b/src/parallax/server/executor/base_executor.py
index a092ac1e..03429bd3 100755
--- a/src/parallax/server/executor/base_executor.py
+++ b/src/parallax/server/executor/base_executor.py
@@ -177,6 +177,10 @@ def __init__(
         if self.shared_state is not None:
            self.shared_state.set_status(ServerState.READY.value)
 
+        # Store max_sequence_length so raw requests can be length-capped.
+        self.max_sequence_length = max_sequence_length
+        self.model_path = None
+
     @abstractmethod
     def handle_input_requests(self, requests: List[Request]):
         """Update requests states and status in scheduler and cache manager."""
@@ -570,10 +574,26 @@ def _handle_raw_request(self, raw_request: Dict):
             prompt = convert_chat(raw_request["messages"], raw_request.get("role_mapping"))
             prompt = self.tokenizer.encode(prompt)
 
+        max_req_len = self.max_sequence_length if self.max_sequence_length is not None else 2048
+        input_token_num = len(prompt)
+        if input_token_num >= max_req_len:
+            logger.warning(
+                f"Input token length {input_token_num} exceeds max_sequence_length {max_req_len}. Truncating input."
+            )
+            now_prompt_len = max(5, max_req_len - 10)  # keep at least 5 tokens; reserve ~10 for generation
+            del prompt[now_prompt_len:]
+            input_token_num = len(prompt)
+
         max_new_tokens = raw_request.get("max_tokens")
-        if max_new_tokens is None:
-            max_new_tokens = 2048
+        logger.debug(f"max_new_tokens from request: {max_new_tokens}")
+        if max_new_tokens is None or (input_token_num + max_new_tokens) >= max_req_len:
+            logger.warning(
+                f"max_new_tokens {max_new_tokens} is unset or input length + max_new_tokens exceeds max_sequence_length {max_req_len}; clamping max_new_tokens."
+            )
+            max_new_tokens = max(0, max_req_len - input_token_num)
         max_total_length = len(prompt) + max_new_tokens
+        logger.debug(f"Final max_new_tokens for request ID {rid}: {max_new_tokens}")
+        logger.debug(f"Final input token length for request ID {rid}: {input_token_num}")
 
         lora_path = raw_request.get("lora_path")
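
Reviewer note: a minimal standalone sketch of the length-capping behavior this patch adds. `cap_request` and `MAX_REQ_LEN` are hypothetical names used only for illustration (the patch operates on `self.max_sequence_length` and inline variables inside `_handle_raw_request`); the arithmetic mirrors the diff above.

    from typing import List, Optional, Tuple

    MAX_REQ_LEN = 2048  # stand-in for self.max_sequence_length (patch default when unset)

    def cap_request(prompt: List[int], max_tokens: Optional[int]) -> Tuple[List[int], int]:
        """Truncate an over-long prompt, then clamp max_new_tokens to fit the window."""
        if len(prompt) >= MAX_REQ_LEN:
            # Mirrors the patch: keep at least 5 tokens, reserve ~10 for generation.
            del prompt[max(5, MAX_REQ_LEN - 10):]
        if max_tokens is None or len(prompt) + max_tokens >= MAX_REQ_LEN:
            max_tokens = max(0, MAX_REQ_LEN - len(prompt))
        return prompt, max_tokens

    # A 2100-token prompt with max_tokens=512 is cut to 2038 tokens and the
    # generation budget is clamped to 2048 - 2038 = 10 tokens.
    tokens, budget = cap_request(list(range(2100)), 512)
    assert len(tokens) == 2038 and budget == 10

One consequence worth noting during review: because the clamp uses `>=`, a request whose prompt is truncated ends up with `prompt + max_new_tokens` exactly equal to `max_req_len`, so generation headroom after truncation is at most the ~10 reserved tokens.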