diff --git a/src/parallax/p2p/message_util.py b/src/parallax/p2p/message_util.py
index 1181c988..725e0d9d 100644
--- a/src/parallax/p2p/message_util.py
+++ b/src/parallax/p2p/message_util.py
@@ -50,6 +50,10 @@ def request_to_proto(
         if request.next_token_id is not None:
             proto_req.next_token_id = request.next_token_id
 
+        # Add token_logit if available
+        if hasattr(request, "token_logit") and request.token_logit is not None:
+            proto_req.token_logit = request.token_logit
+
         forward_request.reqs.append(proto_req)
 
     return forward_request
@@ -86,6 +90,11 @@ def proto_to_request(
 
         sampling_params = proto_to_sampling_params(proto_req.sampling_params)
 
+        # Extract token_logit if present
+        token_logit = None
+        if proto_req.HasField("token_logit"):
+            token_logit = proto_req.token_logit
+
         request = IntermediateRequest(
             request_id=proto_req.rid,
             current_position=current_position,
@@ -96,6 +105,7 @@ def proto_to_request(
             next_token_id=next_token_id,
             sampling_params=sampling_params,
             lora_path=proto_req.lora_path if proto_req.lora_path != "" else None,
+            token_logit=token_logit,
         )
         requests.append(request)
 
diff --git a/src/parallax/p2p/proto/forward.proto b/src/parallax/p2p/proto/forward.proto
index 4854ab57..bac1ac37 100644
--- a/src/parallax/p2p/proto/forward.proto
+++ b/src/parallax/p2p/proto/forward.proto
@@ -36,6 +36,7 @@ message Req {
   int32 next_token_id = 6;
   bytes hidden_states = 7;
   string lora_path = 8;
+  optional float token_logit = 9; // Logit value for the sampled token
 }
 
 message SamplingParams {
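Note for reviewers: `forward_pb2.py` below is generated output, so the serialized descriptor should come from `protoc` rather than hand edits. A minimal regeneration sketch, assuming `grpcio-tools` is installed and the repository root is the working directory (the invocation itself is an assumption, not part of this patch):

```python
# Regenerate forward_pb2.py after editing forward.proto.
# Assumes: pip install grpcio-tools; run from the repository root so that
# the import root matches the path baked into the serialized descriptor.
from grpc_tools import protoc

protoc.main([
    "protoc",                                 # argv[0] placeholder
    "-I.",
    "--python_out=.",
    "src/parallax/p2p/proto/forward.proto",
])
```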
diff --git a/src/parallax/p2p/proto/forward_pb2.py b/src/parallax/p2p/proto/forward_pb2.py
index 9e4e3f69..c60b0994 100644
--- a/src/parallax/p2p/proto/forward_pb2.py
+++ b/src/parallax/p2p/proto/forward_pb2.py
@@ -24,15 +24,15 @@
 
 
 
 
-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n$src/parallax/p2p/proto/forward.proto\x12\x08gradient\"Z\n\x0e\x46orwardRequest\x12+\n\x0c\x66orward_mode\x18\x01 \x01(\x0e\x32\x15.gradient.ForwardMode\x12\x1b\n\x04reqs\x18\x02 \x03(\x0b\x32\r.gradient.Req\"\x11\n\x0f\x46orwardResponse\"+\n\x0c\x41\x62ortRequest\x12\x1b\n\x04reqs\x18\x01 \x03(\x0b\x32\r.gradient.Req\"\x0f\n\rAbortResponse\"\xc7\x01\n\x03Req\x12\x0b\n\x03rid\x18\x01 \x01(\t\x12\x15\n\routput_length\x18\x02 \x01(\x05\x12\x15\n\rrouting_table\x18\x03 \x03(\t\x12\x11\n\tinput_ids\x18\x04 \x03(\x05\x12\x31\n\x0fsampling_params\x18\x05 \x01(\x0b\x32\x18.gradient.SamplingParams\x12\x15\n\rnext_token_id\x18\x06 \x01(\x05\x12\x15\n\rhidden_states\x18\x07 \x01(\x0c\x12\x11\n\tlora_path\x18\x08 \x01(\t\"\xa7\x02\n\x0eSamplingParams\x12\x16\n\x0emax_new_tokens\x18\x01 \x01(\x05\x12\x16\n\x0emin_new_tokens\x18\x02 \x01(\x05\x12\x13\n\x0btemperature\x18\x03 \x01(\x02\x12\r\n\x05top_p\x18\x04 \x01(\x02\x12\r\n\x05min_p\x18\x05 \x01(\x02\x12\r\n\x05top_k\x18\x06 \x01(\x05\x12\x16\n\x0estop_token_ids\x18\x07 \x03(\x05\x12\x12\n\nignore_eos\x18\x08 \x01(\x08\x12\x11\n\tstop_strs\x18\t \x03(\t\x12\x1a\n\x12repetition_penalty\x18\n \x01(\x02\x12\x18\n\x10presence_penalty\x18\x0b \x01(\x02\x12\x19\n\x11\x66requency_penalty\x18\x0c \x01(\x02\x12\x13\n\x0bjson_schema\x18\r \x01(\t*0\n\x0b\x46orwardMode\x12\n\n\x06\x45XTEND\x10\x00\x12\n\n\x06\x44\x45\x43ODE\x10\x01\x12\t\n\x05MIXED\x10\x02\x62\x06proto3')
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n$src/parallax/p2p/proto/forward.proto\x12\x08gradient\"Z\n\x0e\x46orwardRequest\x12+\n\x0c\x66orward_mode\x18\x01 \x01(\x0e\x32\x15.gradient.ForwardMode\x12\x1b\n\x04reqs\x18\x02 \x03(\x0b\x32\r.gradient.Req\"\x11\n\x0f\x46orwardResponse\"+\n\x0c\x41\x62ortRequest\x12\x1b\n\x04reqs\x18\x01 \x03(\x0b\x32\r.gradient.Req\"\x0f\n\rAbortResponse\"\xf1\x01\n\x03Req\x12\x0b\n\x03rid\x18\x01 \x01(\t\x12\x15\n\routput_length\x18\x02 \x01(\x05\x12\x15\n\rrouting_table\x18\x03 \x03(\t\x12\x11\n\tinput_ids\x18\x04 \x03(\x05\x12\x31\n\x0fsampling_params\x18\x05 \x01(\x0b\x32\x18.gradient.SamplingParams\x12\x15\n\rnext_token_id\x18\x06 \x01(\x05\x12\x15\n\rhidden_states\x18\x07 \x01(\x0c\x12\x11\n\tlora_path\x18\x08 \x01(\t\x12\x18\n\x0btoken_logit\x18\t \x01(\x02H\x00\x88\x01\x01\x42\x0e\n\x0c_token_logit\"\xa7\x02\n\x0eSamplingParams\x12\x16\n\x0emax_new_tokens\x18\x01 \x01(\x05\x12\x16\n\x0emin_new_tokens\x18\x02 \x01(\x05\x12\x13\n\x0btemperature\x18\x03 \x01(\x02\x12\r\n\x05top_p\x18\x04 \x01(\x02\x12\r\n\x05min_p\x18\x05 \x01(\x02\x12\r\n\x05top_k\x18\x06 \x01(\x05\x12\x16\n\x0estop_token_ids\x18\x07 \x03(\x05\x12\x12\n\nignore_eos\x18\x08 \x01(\x08\x12\x11\n\tstop_strs\x18\t \x03(\t\x12\x1a\n\x12repetition_penalty\x18\n \x01(\x02\x12\x18\n\x10presence_penalty\x18\x0b \x01(\x02\x12\x19\n\x11\x66requency_penalty\x18\x0c \x01(\x02\x12\x13\n\x0bjson_schema\x18\r \x01(\t*0\n\x0b\x46orwardMode\x12\n\n\x06\x45XTEND\x10\x00\x12\n\n\x06\x44\x45\x43ODE\x10\x01\x12\t\n\x05MIXED\x10\x02\x62\x06proto3')
 _globals = globals()
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
 _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'src.parallax.p2p.proto.forward_pb2', _globals)
 if not _descriptor._USE_C_DESCRIPTORS:
   DESCRIPTOR._loaded_options = None
-  _globals['_FORWARDMODE']._serialized_start=723
-  _globals['_FORWARDMODE']._serialized_end=771
+  _globals['_FORWARDMODE']._serialized_start=765
+  _globals['_FORWARDMODE']._serialized_end=813
   _globals['_FORWARDREQUEST']._serialized_start=50
   _globals['_FORWARDREQUEST']._serialized_end=140
   _globals['_FORWARDRESPONSE']._serialized_start=142
@@ -42,7 +42,7 @@
   _globals['_ABORTRESPONSE']._serialized_start=206
   _globals['_ABORTRESPONSE']._serialized_end=221
   _globals['_REQ']._serialized_start=224
-  _globals['_REQ']._serialized_end=423
-  _globals['_SAMPLINGPARAMS']._serialized_start=426
-  _globals['_SAMPLINGPARAMS']._serialized_end=721
+  _globals['_REQ']._serialized_end=465
+  _globals['_SAMPLINGPARAMS']._serialized_start=468
+  _globals['_SAMPLINGPARAMS']._serialized_end=763
 # @@protoc_insertion_point(module_scope)
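For reviewers unfamiliar with proto3 field presence: declaring `token_logit` as `optional` is what makes the `HasField` check in `proto_to_request` above valid (the extra `H\x00...\x42\x0e\n\x0c_token_logit` bytes in the descriptor are the synthetic oneof protoc emits for it). A small round-trip sketch, assuming the regenerated module is importable as below:

```python
# Presence semantics of the new optional scalar field.
from src.parallax.p2p.proto import forward_pb2

req = forward_pb2.Req(rid="req-1", next_token_id=42)
assert not req.HasField("token_logit")        # unset: receiver keeps None

req.token_logit = -1.73
back = forward_pb2.Req.FromString(req.SerializeToString())
assert back.HasField("token_logit")           # presence survives the wire
assert abs(back.token_logit - (-1.73)) < 1e-6  # float32 round-trip
```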
diff --git a/src/parallax/server/executor/base_executor.py b/src/parallax/server/executor/base_executor.py
index a092ac1e..81d3f7b5 100755
--- a/src/parallax/server/executor/base_executor.py
+++ b/src/parallax/server/executor/base_executor.py
@@ -369,7 +369,15 @@ def prepare_next_batch_requests(
                 hidden_state_for_req = hidden_states[pre_length : pre_length + 1, :]
                 pre_length += 1
-                next_req = self._prepare_next_single_request(src_request, hidden_state_for_req)
+                # Get logit for this request if available
+                token_logit = None
+                if self.is_last_peer and hasattr(self, "_latest_token_logits"):
+                    if self._latest_token_logits is not None and len(self._latest_token_logits) > i:
+                        token_logit = self._latest_token_logits[i]
+
+                next_req = self._prepare_next_single_request(
+                    src_request, hidden_state_for_req, token_logit
+                )
                 batched_requests.append(next_req)
         else:
             batched_requests = None
 
@@ -576,6 +584,7 @@ def _handle_raw_request(self, raw_request: Dict):
         max_total_length = len(prompt) + max_new_tokens
         lora_path = raw_request.get("lora_path")
+        return_logits = raw_request.get("return_logits", False)  # Whether the client asked for per-token logits
 
         raw_sampling_params = raw_request.get("sampling_params")
         if raw_sampling_params is None:
@@ -600,6 +609,7 @@ def _handle_raw_request(self, raw_request: Dict):
             max_new_tokens=max_new_tokens,
             max_total_length=max_total_length,
             lora_path=lora_path,
+            return_logits=return_logits,
         )
         if "routing_table" in raw_request:
             req.routing_table = raw_request["routing_table"]
@@ -633,7 +643,9 @@ def _notify_http_request_error(self, raw_request: Optional[Dict], error: Excepti
         except Exception:  # pragma: no cover - best effort notification
             logger.debug("Failed to send error notification to HTTP handler", exc_info=True)
 
-    def _prepare_next_single_request(self, request: Request, hidden_states: Any) -> Request:
+    def _prepare_next_single_request(
+        self, request: Request, hidden_states: Any, token_logit: Optional[float] = None
+    ) -> Request:
         """Handle request state changes both inter and intra peers.
 
         This function prepares the request object to be sent to the *next* peer in the
@@ -642,6 +654,7 @@ def _prepare_next_single_request(self, request: Request, hidden_states: Any) ->
         Args:
             request: The request that was just processed by this peer.
             hidden_states: The output hidden_states/output_ids from the model for this request.
+            token_logit: The logit value for the sampled token (optional).
 
         Returns:
             A new Request object ready to be sent to the next destination.
@@ -662,6 +675,7 @@ def _prepare_next_single_request(self, request: Request, hidden_states: Any) ->
                 next_token_id=next_token_id,
                 routing_table=request.routing_table,
                 lora_path=request.lora_path,
+                token_logit=token_logit,
             )
             if self.is_last_peer:
                 # Last peer decodes a token and sends it back to the first peer.
@@ -680,6 +694,7 @@ def _prepare_next_single_request(self, request: Request, hidden_states: Any) ->
                 next_token_id=next_token_id,
                 routing_table=request.routing_table,
                 lora_path=request.lora_path,
+                token_logit=token_logit,
             )
 
         # This peer is the first or an intermediate peer.
         if self.is_first_peer:
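The hand-off above relies on an implicit contract: when a last-peer executor sets `_latest_token_logits`, the list is ordered exactly like the request batch, one entry per request, with `None` for slots where no token was sampled. A standalone illustration of how `prepare_next_batch_requests` consumes it (the names are stand-ins, not the executor's real state):

```python
# Positional contract between process_batch and prepare_next_batch_requests.
requests = ["req-a", "req-b", "req-c"]        # stand-ins for Request objects
latest_token_logits = [-0.42, None, -3.10]    # one slot per request, batch order

for i, src_request in enumerate(requests):
    token_logit = None
    if latest_token_logits is not None and len(latest_token_logits) > i:
        token_logit = latest_token_logits[i]
    print(src_request, token_logit)           # req-b correctly gets None
```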
diff --git a/src/parallax/server/executor/mlx_executor.py b/src/parallax/server/executor/mlx_executor.py
index 67c14ca7..2caefead 100755
--- a/src/parallax/server/executor/mlx_executor.py
+++ b/src/parallax/server/executor/mlx_executor.py
@@ -198,6 +198,9 @@ def __init__(
             f"KVCacheManager ready; wired_limit set; prefix_cache={'on' if self.enable_prefix_cache else 'off'}"
         )
 
+        # Store latest sampled token logit values (not full distribution)
+        self._latest_token_logits = None
+
     def handle_input_requests(self, requests: List[Request]):
         """Update requests states and status in scheduler and cache manager."""
         if not requests:
@@ -253,6 +256,12 @@ def handle_input_requests(self, requests: List[Request]):
                     req_dict["eos"] = True
                 if original_req.status == RequestStatus.FINISHED_MAX_LENGTH:
                     req_dict["length"] = True
+
+                # Add logit value for the sampled token (if requested and available)
+                if hasattr(original_req, "return_logits") and original_req.return_logits:
+                    if hasattr(req, "token_logit") and req.token_logit is not None:
+                        req_dict["logits"] = req.token_logit
+
                 if hasattr(self, "send_to_ipc_socket"):
                     self.send_to_ipc_socket.send_pyobj(req_dict)
                 else:
@@ -330,10 +339,36 @@ def process_batch(self, prepared_inputs: Dict[str, Any], return_decoded_tokens:
         # Process last peer: need additional sampling + detokenization
         if return_decoded_tokens:
             sampling_info = SamplingBatchInfo.from_reqs(requests)
-            return mx.array(
+
+            # For MLX, hidden_states at last shard is already logits (after lm_head)
+            # hidden_states shape: [batch_size, seq_len, vocab_size]
+            token_ids = mx.array(
                 self.model_shard.logits_to_tokens(hidden_states, lengths, sampling_info)
             )
 
+            # Extract logit values for sampled tokens
+            try:
+                # Get last position logits for each request
+                batch_logits = []
+                for i in range(len(requests)):
+                    if lengths[i] > 0:
+                        # Get logit at last position
+                        last_idx = int(lengths[i]) - 1
+                        last_logits = hidden_states[i, last_idx, :]  # [vocab_size]
+                        # Extract logit for the sampled token
+                        token_id = int(token_ids[i])
+                        batch_logits.append(float(last_logits[token_id]))
+                    else:
+                        # Keep one slot per request so indices stay aligned with the batch
+                        batch_logits.append(None)
+
+                self._latest_token_logits = batch_logits if batch_logits else None
+            except Exception as e:
+                logger.debug(f"Failed to extract token logits: {e}")
+                self._latest_token_logits = None
+
+            return token_ids
+
         return hidden_states
 
     def _release_request(self, rid: str):
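The per-request loop above is easy to follow but scalar-heavy. For illustration only, the same gather can be written with MLX array ops; this is a sketch with toy shapes, not the executor's code, and it assumes every request has `lengths[i] > 0` (so the `None` padding above is not needed):

```python
import mlx.core as mx

batch, seq_len, vocab = 2, 5, 11
hidden_states = mx.random.normal((batch, seq_len, vocab))  # already logits on the last shard
lengths = mx.array([5, 3])
token_ids = mx.array([7, 2])

# Pick each request's last valid position, then its sampled token's logit.
last_rows = mx.take_along_axis(
    hidden_states, (lengths - 1)[:, None, None], axis=1
)[:, 0, :]                                                 # [batch, vocab]
sampled = mx.take_along_axis(last_rows, token_ids[:, None], axis=1)[:, 0]
print([float(v) for v in sampled])                         # one logit per request
```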
diff --git a/src/parallax/server/executor/sglang_executor.py b/src/parallax/server/executor/sglang_executor.py
index 7637765c..48f5bb95 100755
--- a/src/parallax/server/executor/sglang_executor.py
+++ b/src/parallax/server/executor/sglang_executor.py
@@ -159,6 +159,9 @@ def __init__(
         self.tp_group = self.model_runner.tp_group
         self.tp_cpu_group = self.tp_group.cpu_group
 
+        # Store latest sampled token logits (not full distribution)
+        self._latest_token_logits = None
+
     def check_lora_server_args(self):
         assert self.max_loras_per_batch > 0, "max_loras_per_batch must be positive"
 
@@ -299,6 +302,12 @@ def handle_input_requests(self, requests: List[Request]):
                     req_dict["eos"] = True
                 if original_req.status == RequestStatus.FINISHED_MAX_LENGTH:
                     req_dict["length"] = True
+
+                # Add logit value for the sampled token (if requested and available)
+                if hasattr(original_req, "return_logits") and original_req.return_logits:
+                    if hasattr(req, "token_logit") and req.token_logit is not None:
+                        req_dict["logits"] = req.token_logit
+
                 if hasattr(self, "send_to_ipc_socket"):
                     self.send_to_ipc_socket.send_pyobj(req_dict)
                 else:
@@ -349,6 +358,18 @@ def process_batch(self, prepared_inputs: Dict[str, Any], return_decoded_tokens:
         if return_decoded_tokens:
             # Last peer: sample and return token IDs
             next_token_ids = self.model_runner.sample(logits_output, forward_batch)
+
+            # Extract logits for the sampled tokens
+            if hasattr(logits_output, "next_token_logits"):
+                # Build the row indices on the same device as the ids to avoid a CPU/GPU mismatch
+                real_logits = logits_output.next_token_logits[
+                    torch.arange(len(next_token_ids), device=next_token_ids.device),
+                    next_token_ids,
+                ]
+                self._latest_token_logits = real_logits.cpu().float().tolist()
+            else:
+                self._latest_token_logits = None
+
             return next_token_ids
         else:
             # Intermediate peer: return hidden states for next peer
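A standalone sketch of the gather in `process_batch` above: pick each sampled token's logit out of a `[batch, vocab]` score matrix. Keeping the `arange` on the same device as the ids is what the `device=` argument in the patch is for; the shapes here are toy values.

```python
import torch

next_token_logits = torch.randn(4, 32000)         # [batch, vocab] scores
next_token_ids = torch.tensor([5, 17, 998, 3])    # sampled ids, [batch]

# Row indices live on the same device as the column indices.
rows = torch.arange(len(next_token_ids), device=next_token_ids.device)
real_logits = next_token_logits[rows, next_token_ids]
print(real_logits.cpu().float().tolist())         # one float per request
```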
diff --git a/src/parallax/server/http_server.py b/src/parallax/server/http_server.py
index de1e9573..d8583cfd 100644
--- a/src/parallax/server/http_server.py
+++ b/src/parallax/server/http_server.py
@@ -22,7 +22,7 @@
 import uuid
 from dataclasses import dataclass, field
 from http import HTTPStatus
-from typing import Dict, Optional
+from typing import Dict, List, Optional
 
 import fastapi
 import uvicorn
@@ -87,6 +87,9 @@ class HTTPRequestInfo:
     error_message: Optional[str] = None
     error_type: Optional[str] = None
     error_status: HTTPStatus = HTTPStatus.INTERNAL_SERVER_ERROR
+    # logits support
+    return_logits: bool = False  # Whether to return logits
+    logits_list: List[float] = field(default_factory=list)  # Store logits for each token
 
 
 class HTTPHandler:
@@ -128,6 +131,7 @@ def create_request(self, request: Dict):
         rid = request["rid"]
         stream = request.get("stream", False)
         model = request.get("model", "default")
+        return_logits = request.get("return_logits", False)  # Check if logits requested
         chat_object = "chat.completion.chunk" if stream else "chat.completion"
         detokenizer = self.detokenizer_class(self.tokenizer, self.tokenmap)
         create_time = time.time()
@@ -140,6 +144,7 @@ def create_request(self, request: Dict):
             create_time=create_time,
             update_time=update_time,
             detokenizer=detokenizer,
+            return_logits=return_logits,
         )
         if stream:
             request_info.token_queue = asyncio.Queue()
@@ -151,6 +156,11 @@ def release_request(self, rid: str):
 
     def send_request(self, request: Dict):
         """Sends the request to model executor using IPC."""
+        # Ensure return_logits is included in the request sent to executor
+        rid = request.get("rid")
+        if rid and rid in self.processing_requests:
+            request_info = self.processing_requests[rid]
+            request["return_logits"] = request_info.return_logits
         self.send_to_executor.send_pyobj(request)
 
     def abort_request(self, request_id: str):
@@ -280,6 +290,9 @@ def generate_non_stream_response(self, rid):
             "reasoning_content": None,
             "tool_calls": None,
         }
+        # Add logits if requested
+        if request_info.return_logits:
+            choice["logits"] = request_info.logits_list
         return response
 
     async def _handle_executor_error(self, rid: str, recv_dict: Dict):
@@ -331,6 +344,10 @@ async def _handle_loop(self):
                 request_info.detokenizer.add_token(next_token_id)
                 output = request_info.detokenizer.last_segment
 
+                # Store logits if requested
+                if request_info.return_logits and "logits" in recv_dict:
+                    request_info.logits_list.append(recv_dict["logits"])
+
                 is_finished = recv_dict.get("eos", False) or recv_dict.get("length", False)
 
                 # Only process and send non-EOS tokens
diff --git a/src/parallax/server/request.py b/src/parallax/server/request.py
index f76accdc..841f1b37 100644
--- a/src/parallax/server/request.py
+++ b/src/parallax/server/request.py
@@ -158,6 +158,7 @@ def __init__(
         max_total_length: int = 1024,
         status: RequestStatus = RequestStatus.PREFILLING,
         lora_path: Optional[str] = None,
+        return_logits: bool = False,
     ):
         if not prompt and not input_ids:
             raise ValueError("prompt or input_ids cannot be empty.")
@@ -170,6 +171,7 @@ def __init__(
             lora_path=lora_path,
         )
         self.prompt = prompt
+        self.return_logits = return_logits
 
         if max_new_tokens < 1:
             raise ValueError("max_new_tokens must be at least 1.")
@@ -262,6 +264,7 @@ def __init__(
         routing_table: Optional[List[str]] = [],
         sampling_params: Optional[SamplingParams] = None,
         lora_path: Optional[str] = None,
+        token_logit: Optional[float] = None,
     ):
         super().__init__(
             request_id=request_id,
@@ -283,6 +286,7 @@ def __init__(
         self.current_position = current_position
         self.hidden_states = hidden_states
         self.next_token_id = next_token_id
+        self.token_logit = token_logit
 
     @property
     def input_length(self) -> int:
@@ -301,6 +305,7 @@ def from_initial_request(
         initial_request: InitialRequest,
         hidden_states: Optional[Any] = None,
         lora_path: Optional[str] = None,
+        token_logit: Optional[float] = None,
     ) -> "IntermediateRequest":
         """Convert an InitialRequest to an IntermediateRequest.
 
@@ -333,6 +338,7 @@ def from_initial_request(
             sampling_params=initial_request.sampling_params,
             routing_table=initial_request.routing_table,
             lora_path=lora_path,
+            token_logit=token_logit,
         )
 
     @classmethod
@@ -341,6 +347,7 @@ def from_intermediate_request(
         old_request: "IntermediateRequest",
         new_hidden_states: Any,
         lora_path: Optional[str] = None,
+        token_logit: Optional[float] = None,
     ) -> "IntermediateRequest":
         """
         Creates a new IntermediateRequest from an old one, with updated hidden states.
@@ -356,6 +363,7 @@
             routing_table=old_request.routing_table,
             sampling_params=old_request.sampling_params,
             lora_path=lora_path,
+            token_logit=token_logit,
         )
 
     def __repr__(self):
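End-to-end sketch of the feature this patch wires up: opt in to per-token logits on a non-streaming request and read them back from the per-choice `logits` list added in `generate_non_stream_response`. The endpoint path and port are assumptions; only `return_logits` in the request body and `logits` on the choice come from the patch itself.

```python
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",   # assumed address/route
    json={
        "model": "default",
        "messages": [{"role": "user", "content": "Hi"}],
        "return_logits": True,   # opt in; defaults to False
        "stream": False,         # logits are attached to the final choice
    },
)
choice = resp.json()["choices"][0]
print(choice.get("logits"))      # one raw (pre-softmax) logit per generated token
```

Because each hop serializes `token_logit` into the `Req` proto, this works regardless of which peer in the pipeline sampled the token; the value rides back to the first peer alongside `next_token_id`.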