Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions src/parallax/server/executor/base_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,10 @@ def __init__(
)
if self.shared_state is not None:
self.shared_state.set_status(ServerState.READY.value)

# Keep max_sequence_length so request handling can cap prompt and generation length
self.max_sequence_length = max_sequence_length
self.model_path = None

@abstractmethod
def handle_input_requests(self, requests: List[Request]):
Expand Down Expand Up @@ -569,11 +573,27 @@ def _handle_raw_request(self, raw_request: Dict):
else:
prompt = convert_chat(raw_request["messages"], raw_request.get("role_mapping"))
prompt = self.tokenizer.encode(prompt)

max_req_len = self.max_sequence_length if self.max_sequence_length is not None else 2048
input_token_num = len(prompt)
if (input_token_num >= max_req_len):
logger.warning(
f"Input token length {input_token_num} exceeds max_sequence_length {max_req_len}. Truncating input."
)
now_prompt_len = max(5, max_req_len - 10)
del prompt[now_prompt_len:]
input_token_num = len(prompt)

max_new_tokens = raw_request.get("max_tokens")
if max_new_tokens is None:
max_new_tokens = 2048
logger.debug(f"max_new_tokens from request: {max_new_tokens}")
if max_new_tokens is None or (input_token_num + max_new_tokens) >= max_req_len:
logger.warning(
f"max_new_tokens {max_new_tokens} is None or input length + max_new_tokens exceeds max_sequence_length {max_req_len}. Adjusting max_new_tokens."
)
max_new_tokens = max(0, max_req_len - input_token_num)
max_total_length = len(prompt) + max_new_tokens
logger.debug(f"Final max_new_tokens for request ID {rid}: {max_new_tokens}")
logger.debug(f"Final input token length for request ID {rid}: {input_token_num}")

lora_path = raw_request.get("lora_path")

Expand Down
Loading