Commit fbd13bb

[Feat] Add structured generation OpenAI API (#1114)
1 parent ad8b4ae commit fbd13bb

5 files changed: +125 −9 lines changed

5 files changed

+125
-9
lines changed

lightllm/server/api_models.py

Lines changed: 33 additions & 3 deletions
@@ -1,8 +1,8 @@
 import time
+import uuid
 
 from pydantic import BaseModel, Field, field_validator
-from typing import Dict, List, Optional, Union, Literal
-import uuid
+from typing import Any, Dict, List, Optional, Union, Literal
 
 
 class ImageURL(BaseModel):
@@ -52,6 +52,21 @@ class StreamOptions(BaseModel):
     include_usage: Optional[bool] = False
 
 
+class JsonSchemaResponseFormat(BaseModel):
+    name: str
+    description: Optional[str] = None
+    # schema is the field in openai but that causes conflicts with pydantic so
+    # instead use json_schema with an alias
+    json_schema: Optional[dict[str, Any]] = Field(default=None, alias="schema")
+    strict: Optional[bool] = None
+
+
+class ResponseFormat(BaseModel):
+    # type must be "json_schema", "json_object", or "text"
+    type: Literal["text", "json_object", "json_schema"]
+    json_schema: Optional[JsonSchemaResponseFormat] = None
+
+
 class CompletionRequest(BaseModel):
     model: str
     # prompt: string or tokens
@@ -71,6 +86,14 @@ class CompletionRequest(BaseModel):
     best_of: Optional[int] = 1
     logit_bias: Optional[Dict[str, float]] = None
     user: Optional[str] = None
+    response_format: Optional[ResponseFormat] = Field(
+        default=None,
+        description=(
+            "Similar to chat completion, this parameter specifies the format "
+            "of output. Only {'type': 'json_object'}, {'type': 'json_schema'}"
+            ", or {'type': 'text' } is supported."
+        ),
+    )
 
     # Additional parameters supported by LightLLM
     do_sample: Optional[bool] = False
@@ -94,7 +117,14 @@ class ChatCompletionRequest(BaseModel):
     frequency_penalty: Optional[float] = 0.0
     logit_bias: Optional[Dict[str, float]] = None
     user: Optional[str] = None
-    response_format: Optional[Dict] = None
+    response_format: Optional[ResponseFormat] = Field(
+        default=None,
+        description=(
+            "Similar to chat completion, this parameter specifies the format "
+            "of output. Only {'type': 'json_object'}, {'type': 'json_schema'}"
+            ", or {'type': 'text' } is supported."
+        ),
+    )
 
     # OpenAI Adaptive parameters for tool call
     tools: Optional[List[Tool]] = Field(default=None, examples=[None])
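
Taken together, these models replace the old untyped response_format dict with a validated structure. A minimal sketch of how a wire payload would parse against them (assuming pydantic v2, which the field_validator import implies; model_validate is its parsing entry point):

payload = {
    "type": "json_schema",
    "json_schema": {
        "name": "PersonInfo",
        # The wire field is "schema"; the alias maps it onto json_schema
        # to sidestep the conflict with pydantic's own schema attribute.
        "schema": {"type": "object", "properties": {"name": {"type": "string"}}},
    },
}
rf = ResponseFormat.model_validate(payload)
assert rf.type == "json_schema"
assert rf.json_schema.json_schema["type"] == "object"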

lightllm/server/api_openai.py

Lines changed: 19 additions & 4 deletions
@@ -175,11 +175,17 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req
         "best_of": request.n,
         "add_special_tokens": False,
     }
+
+    # Structured output handling
     if request.response_format:
-        obj = request.response_format.get("schema")
-        if obj:
-            # guided_json takes str instead of dict obj
-            sampling_params_dict["guided_json"] = json.dumps(obj)
+        if request.response_format.type == "json_schema":
+            obj = request.response_format.json_schema
+            if obj:
+                # guided_json takes str instead of dict obj
+                sampling_params_dict["guided_json"] = json.dumps(obj.json_schema)
+        elif request.response_format.type == "json_object":
+            sampling_params_dict["guided_grammar"] = "json"
+
     sampling_params = SamplingParams()
     sampling_params.init(tokenizer=g_objs.httpserver_manager.tokenizer, **sampling_params_dict)

@@ -453,6 +459,15 @@ async def completions_impl(request: CompletionRequest, raw_request: Request) ->
         "add_special_tokens": False,
     }
 
+    if request.response_format:
+        if request.response_format.type == "json_schema":
+            obj = request.response_format.json_schema
+            if obj:
+                # guided_json takes str instead of dict obj
+                sampling_params_dict["guided_json"] = json.dumps(obj.json_schema)
+        elif request.response_format.type == "json_object":
+            sampling_params_dict["guided_grammar"] = "json"
+
     sampling_params = SamplingParams()
     sampling_params.init(tokenizer=g_objs.httpserver_manager.tokenizer, **sampling_params_dict)
     sampling_params.verify()
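
Both endpoints now route a json_schema through guided_json and a bare json_object through the builtin guided_grammar. A hedged end-to-end sketch of calling the new API (the host, port, and model name are placeholders for whatever the server was launched with):

import requests

url = "http://localhost:8000/v1/chat/completions"
payload = {
    "model": "placeholder-model",
    "messages": [{"role": "user", "content": "Describe a person as JSON."}],
    "max_tokens": 150,
    # Using {"type": "json_object"} here instead would select the
    # builtin JSON grammar path rather than schema-guided decoding.
    "response_format": {
        "type": "json_schema",
        "json_schema": {
            "name": "PersonInfo",
            "schema": {
                "type": "object",
                "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
                "required": ["name", "age"],
            },
        },
    },
}
resp = requests.post(url, json=payload)
print(resp.json()["choices"][0]["message"]["content"])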

lightllm/server/core/objs/sampling_params.py

Lines changed: 1 addition & 1 deletion
@@ -142,7 +142,7 @@ def initialize(self, constraint: str, tokenizer):
         ctypes.memmove(self.constraint, constraint_bytes, len(constraint_bytes))
         self.length = len(constraint_bytes)
         try:
-            if self.length > 0 and tokenizer is not None:
+            if self.length > 0 and tokenizer is not None and constraint != "json":
                 import xgrammar as xgr
 
                 tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer)
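
The added constraint != "json" clause keeps this eager validation path from handing the "json" sentinel to xgrammar as if it were EBNF text; the sentinel is resolved to the builtin JSON grammar later, in the xgrammar backend below. A minimal sketch of the guard in isolation (should_precompile is a hypothetical name, not part of the commit):

def should_precompile(constraint: str, tokenizer) -> bool:
    # "json" is a sentinel for xgrammar's builtin JSON grammar, not an
    # EBNF grammar, so it is skipped here and compiled by the backend.
    return len(constraint) > 0 and tokenizer is not None and constraint != "json"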

lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_xgrammar_mode.py

Lines changed: 5 additions & 1 deletion
@@ -35,7 +35,11 @@ def get_cached_grammar(type: str, grammar: str):
        logger.info(f"grammar cache miss for {type}: '{grammar}'")
        try:
            if type == "grammar":
-                return self.xgrammar_compiler.compile_grammar(grammar)
+                if grammar == "json":
+                    return self.xgrammar_compiler.compile_builtin_json_grammar()
+
+                else:
+                    return self.xgrammar_compiler.compile_grammar(grammar)
            elif type == "schema":
                return self.xgrammar_compiler.compile_json_schema(grammar)
            else:
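
compile_builtin_json_grammar() is xgrammar's pre-baked grammar for any syntactically valid JSON document, which is why the "json" sentinel needs no schema text. A standalone sketch of the two compile paths (assuming xgrammar and transformers are installed; gpt2 is only a placeholder tokenizer):

import xgrammar as xgr
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer
tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer)
compiler = xgr.GrammarCompiler(tokenizer_info)

# The "json" sentinel maps to the builtin grammar: unconstrained but valid JSON.
compiled_json = compiler.compile_builtin_json_grammar()

# Any other grammar string is treated as EBNF and compiled normally.
compiled_ebnf = compiler.compile_grammar('root ::= "yes" | "no"')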

test/test_api/test_openai_api.py

Lines changed: 67 additions & 0 deletions
@@ -635,6 +635,72 @@ def test_multiple_token_arrays():
         print(f"Error: {e}")
 
 
+def test_structured_generation():
+    """Test structured generation."""
+    client = LightLLMClient()
+
+    try:
+        print("=== Testing structured generation ===")
+        prompt = "Please provide, in JSON format, information about a person including name, age, and occupation."
+
+        # Test JSON generation
+        result = client.completions(prompt, max_tokens=150, response_format={"type": "json_object"})
+        print("Prompt:", prompt)
+        print("Assistant:", result["choices"][0]["text"])
+
+        # Test JSON Schema generation
+        schema = {
+            "type": "object",
+            "properties": {
+                "name": {"type": "string"},
+                "age": {"type": "integer"},
+                "occupation": {"type": "string"},
+            },
+            "required": ["name", "age", "occupation"],
+        }
+        result = client.completions(
+            prompt,
+            max_tokens=150,
+            response_format={
+                "type": "json_schema",
+                "json_schema": {
+                    "name": "PersonInfo",
+                    "description": "Information about a person including name, age, and occupation",
+                    "schema": schema,
+                },
+            },
+        )
+        print("Prompt:", prompt)
+        print("Assistant:", result["choices"][0]["text"])
+
+        # Test JSON generation via the /v1/chat/completions endpoint
+        result = client.simple_chat(
+            prompt,
+            max_tokens=150,
+            response_format={"type": "json_object"},
+        )
+        print("Prompt:", prompt)
+        print("Assistant:", result["choices"][0]["message"]["content"])
+
+        # Test JSON Schema generation via the /v1/chat/completions endpoint
+        result = client.simple_chat(
+            prompt,
+            max_tokens=150,
+            response_format={
+                "type": "json_schema",
+                "json_schema": {
+                    "name": "PersonInfo",
+                    "description": "Information about a person including name, age, and occupation",
+                    "schema": schema,
+                },
+            },
+        )
+        print("Prompt:", prompt)
+        print("Assistant:", result["choices"][0]["message"]["content"])
+    except Exception as e:
+        print(f"Error: {e}")
+
+
 def main():
     # Basic functionality tests
     test_completions()
@@ -651,6 +717,7 @@ def main():
     test_logprobs()
     test_echo()
     test_stop_parameter()
+    test_structured_generation()
 
 
 if __name__ == "__main__":
