10 changes: 10 additions & 0 deletions src/parallax/p2p/message_util.py
@@ -50,6 +50,10 @@ def request_to_proto(
if request.next_token_id is not None:
proto_req.next_token_id = request.next_token_id

# Add token_logit if available
if hasattr(request, "token_logit") and request.token_logit is not None:
proto_req.token_logit = request.token_logit

forward_request.reqs.append(proto_req)

return forward_request
@@ -86,6 +90,11 @@ def proto_to_request(

sampling_params = proto_to_sampling_params(proto_req.sampling_params)

# Extract token_logit if present
token_logit = None
if proto_req.HasField("token_logit"):
token_logit = proto_req.token_logit

request = IntermediateRequest(
request_id=proto_req.rid,
current_position=current_position,
@@ -96,6 +105,7 @@
next_token_id=next_token_id,
sampling_params=sampling_params,
lora_path=proto_req.lora_path if proto_req.lora_path != "" else None,
token_logit=token_logit,
)

requests.append(request)
1 change: 1 addition & 0 deletions src/parallax/p2p/proto/forward.proto
@@ -36,6 +36,7 @@ message Req {
int32 next_token_id = 6;
bytes hidden_states = 7;
string lora_path = 8;
optional float token_logit = 9; // Logit value for the sampled token
}

message SamplingParams {
12 changes: 6 additions & 6 deletions src/parallax/p2p/proto/forward_pb2.py

Generated protobuf bindings (regenerated from forward.proto); diff not rendered.
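Note: declaring the field as optional float token_logit gives it explicit presence in proto3, which is what lets proto_to_request call HasField above; a plain float field would always read as 0.0 with no way to distinguish "unset" from a genuine zero logit. A minimal sketch of that behavior, assuming the regenerated module is importable as parallax.p2p.proto.forward_pb2 (package path assumed, not shown in this diff):

    # Sketch only: module path assumed from the repository layout above.
    from parallax.p2p.proto import forward_pb2

    req = forward_pb2.Req()
    print(req.HasField("token_logit"))   # False: the optional field starts unset

    req.token_logit = -1.25
    print(req.HasField("token_logit"))   # True once a value is assigned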

19 changes: 17 additions & 2 deletions src/parallax/server/executor/base_executor.py
@@ -369,7 +369,15 @@ def prepare_next_batch_requests(
hidden_state_for_req = hidden_states[pre_length : pre_length + 1, :]
pre_length += 1

next_req = self._prepare_next_single_request(src_request, hidden_state_for_req)
# Get logit for this request if available
token_logit = None
if self.is_last_peer and hasattr(self, "_latest_token_logits"):
if self._latest_token_logits is not None and len(self._latest_token_logits) > i:
token_logit = self._latest_token_logits[i]

next_req = self._prepare_next_single_request(
src_request, hidden_state_for_req, token_logit
)
batched_requests.append(next_req)
else:
batched_requests = None
@@ -576,6 +584,7 @@ def _handle_raw_request(self, raw_request: Dict):
max_total_length = len(prompt) + max_new_tokens

lora_path = raw_request.get("lora_path")
return_logits = raw_request.get("return_logits", False) # Get return_logits parameter

raw_sampling_params = raw_request.get("sampling_params")
if raw_sampling_params is None:
@@ -600,6 +609,7 @@
max_new_tokens=max_new_tokens,
max_total_length=max_total_length,
lora_path=lora_path,
return_logits=return_logits,
)
if "routing_table" in raw_request:
req.routing_table = raw_request["routing_table"]
@@ -633,7 +643,9 @@ def _notify_http_request_error(self, raw_request: Optional[Dict], error: Excepti
except Exception: # pragma: no cover - best effort notification
logger.debug("Failed to send error notification to HTTP handler", exc_info=True)

def _prepare_next_single_request(self, request: Request, hidden_states: Any) -> Request:
def _prepare_next_single_request(
self, request: Request, hidden_states: Any, token_logit: Optional[float] = None
) -> Request:
"""Handle request state changes both inter and intra peers.

This function prepares the request object to be sent to the *next* peer in the
@@ -642,6 +654,7 @@
Args:
request: The request that was just processed by this peer.
hidden_states: The output hidden_states/output_ids from the model for this request.
token_logit: The logit value for the sampled token (optional).

Returns:
A new Request object ready to be sent to the next destination.
@@ -662,6 +675,7 @@
next_token_id=next_token_id,
routing_table=request.routing_table,
lora_path=request.lora_path,
token_logit=token_logit,
)
if self.is_last_peer:
# Last peer decodes a token and sends it back to the first peer.
@@ -680,6 +694,7 @@
next_token_id=next_token_id,
routing_table=request.routing_table,
lora_path=request.lora_path,
token_logit=token_logit,
)
# This peer is the first or an intermediate peer.
if self.is_first_peer:
35 changes: 34 additions & 1 deletion src/parallax/server/executor/mlx_executor.py
@@ -198,6 +198,9 @@ def __init__(
f"KVCacheManager ready; wired_limit set; prefix_cache={'on' if self.enable_prefix_cache else 'off'}"
)

# Store latest sampled token logit values (not full distribution)
self._latest_token_logits = None

def handle_input_requests(self, requests: List[Request]):
"""Update requests states and status in scheduler and cache manager."""
if not requests:
@@ -253,6 +256,12 @@ def handle_input_requests(self, requests: List[Request]):
req_dict["eos"] = True
if original_req.status == RequestStatus.FINISHED_MAX_LENGTH:
req_dict["length"] = True

# Add logit value for the sampled token (if requested and available)
if hasattr(original_req, "return_logits") and original_req.return_logits:
if hasattr(req, "token_logit") and req.token_logit is not None:
req_dict["logits"] = req.token_logit

if hasattr(self, "send_to_ipc_socket"):
self.send_to_ipc_socket.send_pyobj(req_dict)
else:
@@ -330,10 +339,34 @@ def process_batch(self, prepared_inputs: Dict[str, Any], return_decoded_tokens:
# Process last peer: need additional sampling + detokenization
if return_decoded_tokens:
sampling_info = SamplingBatchInfo.from_reqs(requests)
return mx.array(

# For MLX, hidden_states at last shard is already logits (after lm_head)
# hidden_states shape: [batch_size, seq_len, vocab_size]
token_ids = mx.array(
self.model_shard.logits_to_tokens(hidden_states, lengths, sampling_info)
)

# Extract logit values for sampled tokens
try:
# Get last position logits for each request
batch_logits = []
for i, req in enumerate(requests):
if lengths[i] > 0:
# Get logit at last position
last_idx = int(lengths[i]) - 1
last_logits = hidden_states[i, last_idx, :] # [vocab_size]
# Extract logit for the sampled token
token_id = int(token_ids[i])
logit_value = float(last_logits[token_id])
batch_logits.append(logit_value)

self._latest_token_logits = batch_logits if batch_logits else None
except Exception as e:
logger.debug(f"Failed to extract token logits: {e}")
self._latest_token_logits = None

return token_ids

return hidden_states

def _release_request(self, rid: str):
20 changes: 20 additions & 0 deletions src/parallax/server/executor/sglang_executor.py
@@ -159,6 +159,9 @@ def __init__(
self.tp_group = self.model_runner.tp_group
self.tp_cpu_group = self.tp_group.cpu_group

# Store latest sampled token logits (not full distribution)
self._latest_token_logits = None

def check_lora_server_args(self):
assert self.max_loras_per_batch > 0, "max_loras_per_batch must be positive"

@@ -299,6 +302,12 @@ def handle_input_requests(self, requests: List[Request]):
req_dict["eos"] = True
if original_req.status == RequestStatus.FINISHED_MAX_LENGTH:
req_dict["length"] = True

# Add logit value for the sampled token (if requested and available)
if hasattr(original_req, "return_logits") and original_req.return_logits:
if hasattr(req, "token_logit") and req.token_logit is not None:
req_dict["logits"] = req.token_logit

if hasattr(self, "send_to_ipc_socket"):
self.send_to_ipc_socket.send_pyobj(req_dict)
else:
@@ -349,6 +358,17 @@ def process_batch(self, prepared_inputs: Dict[str, Any], return_decoded_tokens:
if return_decoded_tokens:
# Last peer: sample and return token IDs
next_token_ids = self.model_runner.sample(logits_output, forward_batch)

# Extract logits for the sampled tokens
if hasattr(logits_output, "next_token_logits"):
# Get logits for sampled tokens
real_logits = logits_output.next_token_logits[
torch.arange(len(next_token_ids)), next_token_ids
]
self._latest_token_logits = real_logits.cpu().float().tolist()
else:
self._latest_token_logits = None

return next_token_ids
else:
# Intermediate peer: return hidden states for next peer
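The indexing in the hunk above picks, for each request in the batch, the raw logit of the token that was actually sampled. A standalone sketch of that gather pattern with toy tensors (shapes assumed to be [batch, vocab] for next_token_logits, as the code implies):

    import torch

    # Toy batch: 3 requests over a vocabulary of 5 tokens.
    next_token_logits = torch.randn(3, 5)
    next_token_ids = torch.tensor([4, 0, 2])

    # Row-wise advanced indexing: picks next_token_logits[i, next_token_ids[i]]
    # for every row i, i.e. one float per request.
    picked = next_token_logits[torch.arange(len(next_token_ids)), next_token_ids]
    print(picked.cpu().float().tolist())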
19 changes: 18 additions & 1 deletion src/parallax/server/http_server.py
@@ -22,7 +22,7 @@
import uuid
from dataclasses import dataclass, field
from http import HTTPStatus
from typing import Dict, Optional
from typing import Dict, List, Optional

import fastapi
import uvicorn
@@ -87,6 +87,9 @@ class HTTPRequestInfo:
error_message: Optional[str] = None
error_type: Optional[str] = None
error_status: HTTPStatus = HTTPStatus.INTERNAL_SERVER_ERROR
# logits support
return_logits: bool = False # Whether to return logits
logits_list: List = field(default_factory=list) # Store logits for each token


class HTTPHandler:
@@ -128,6 +131,7 @@ def create_request(self, request: Dict):
rid = request["rid"]
stream = request.get("stream", False)
model = request.get("model", "default")
return_logits = request.get("return_logits", False) # Check if logits requested
chat_object = "chat.completion.chunk" if stream else "chat.completion"
detokenizer = self.detokenizer_class(self.tokenizer, self.tokenmap)
create_time = time.time()
@@ -140,6 +144,7 @@
create_time=create_time,
update_time=update_time,
detokenizer=detokenizer,
return_logits=return_logits,
)
if stream:
request_info.token_queue = asyncio.Queue()
@@ -151,6 +156,11 @@ def release_request(self, rid: str):

def send_request(self, request: Dict):
"""Sends the request to model executor using IPC."""
# Ensure return_logits is included in the request sent to executor
rid = request.get("rid")
if rid and rid in self.processing_requests:
request_info = self.processing_requests[rid]
request["return_logits"] = request_info.return_logits
self.send_to_executor.send_pyobj(request)

def abort_request(self, request_id: str):
@@ -280,6 +290,9 @@ def generate_non_stream_response(self, rid):
"reasoning_content": None,
"tool_calls": None,
}
# Add logits if requested
if request_info.return_logits:
choice["logits"] = request_info.logits_list
return response

async def _handle_executor_error(self, rid: str, recv_dict: Dict):
@@ -331,6 +344,10 @@ async def _handle_loop(self):
request_info.detokenizer.add_token(next_token_id)
output = request_info.detokenizer.last_segment

# Store logits if requested
if request_info.return_logits and "logits" in recv_dict:
request_info.logits_list.append(recv_dict["logits"])

is_finished = recv_dict.get("eos", False) or recv_dict.get("length", False)

# Only process and send non-EOS tokens
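End to end, a client opts in per request with return_logits and reads the collected values back from the response. A hypothetical client sketch, assuming an OpenAI-style chat completions route and response shape (neither is shown in this diff; only the return_logits flag and the per-choice "logits" field come from this change):

    # Hypothetical usage sketch: URL, port and payload shape are assumptions.
    import requests

    resp = requests.post(
        "http://localhost:8000/v1/chat/completions",
        json={
            "model": "default",
            "messages": [{"role": "user", "content": "Hello"}],
            "return_logits": True,  # ask the server to report each sampled token's logit
            "stream": False,
        },
        timeout=60,
    )
    choice = resp.json()["choices"][0]
    # With return_logits=True the non-stream response carries one float per
    # generated token under "logits" (see generate_non_stream_response above).
    print(choice.get("logits"))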
8 changes: 8 additions & 0 deletions src/parallax/server/request.py
@@ -158,6 +158,7 @@ def __init__(
max_total_length: int = 1024,
status: RequestStatus = RequestStatus.PREFILLING,
lora_path: Optional[str] = None,
return_logits: bool = False,
):
if not prompt and not input_ids:
raise ValueError("prompt or input_ids cannot be empty.")
@@ -170,6 +171,7 @@
lora_path=lora_path,
)
self.prompt = prompt
self.return_logits = return_logits

if max_new_tokens < 1:
raise ValueError("max_new_tokens must be at least 1.")
@@ -262,6 +264,7 @@ def __init__(
routing_table: Optional[List[str]] = [],
sampling_params: Optional[SamplingParams] = None,
lora_path: Optional[str] = None,
token_logit: Optional[float] = None,
):
super().__init__(
request_id=request_id,
@@ -283,6 +286,7 @@
self.current_position = current_position
self.hidden_states = hidden_states
self.next_token_id = next_token_id
self.token_logit = token_logit

@property
def input_length(self) -> int:
@@ -301,6 +305,7 @@ def from_initial_request(
initial_request: InitialRequest,
hidden_states: Optional[Any] = None,
lora_path: Optional[str] = None,
token_logit: Optional[float] = None,
) -> "IntermediateRequest":
"""Convert an InitialRequest to an IntermediateRequest.

@@ -333,6 +338,7 @@
sampling_params=initial_request.sampling_params,
routing_table=initial_request.routing_table,
lora_path=lora_path,
token_logit=token_logit,
)

@classmethod
@@ -341,6 +347,7 @@ def from_intermediate_request(
old_request: "IntermediateRequest",
new_hidden_states: Any,
lora_path: Optional[str] = None,
token_logit: Optional[float] = None,
) -> "IntermediateRequest":
"""
Creates a new IntermediateRequest from an old one, with updated hidden states.
@@ -356,6 +363,7 @@
routing_table=old_request.routing_table,
sampling_params=old_request.sampling_params,
lora_path=lora_path,
token_logit=token_logit,
)

def __repr__(self):