Commit 5fcd220

feat: add usage to streaming response
1 parent 63fc309 commit 5fcd220

3 files changed: +140 −82 lines changed

llama_cpp/llama.py

Lines changed: 110 additions & 77 deletions
@@ -1057,6 +1057,50 @@ def decode_batch(seq_sizes: List[int]):
         else:
             return output
 
+    def _create_chunk(
+        self,
+        completion_id: str,
+        created: int,
+        model_name: str,
+        text: str,
+        logprobs_or_none: Union[Optional[CompletionLogprobs], None],
+        index: int,
+        finish_reason: Union[str, None],
+        usage: Optional[Dict[str, Any]] = None,
+    ) -> CreateCompletionStreamResponse:
+        """Create chunks for streaming API, depending on whether usage is requested or not."""
+        if usage is not None:
+            return {
+                "id": completion_id,
+                "object": "text_completion",
+                "created": created,
+                "model": model_name,
+                "choices": [
+                    {
+                        "text": text,
+                        "index": index,
+                        "logprobs": logprobs_or_none,
+                        "finish_reason": finish_reason,
+                    }
+                ],
+                "usage": usage,
+            }
+        else:
+            return {
+                "id": completion_id,
+                "object": "text_completion",
+                "created": created,
+                "model": model_name,
+                "choices": [
+                    {
+                        "text": text,
+                        "index": index,
+                        "logprobs": logprobs_or_none,
+                        "finish_reason": finish_reason,
+                    }
+                ],
+            }
+
     def _create_completion(
         self,
         prompt: Union[str, List[int]],
@@ -1383,24 +1427,20 @@ def logit_bias_processor(
                         "top_logprobs": [top_logprob],
                     }
                 returned_tokens += 1
-                yield {
-                    "id": completion_id,
-                    "object": "text_completion",
-                    "created": created,
-                    "model": model_name,
-                    "choices": [
-                        {
-                            "text": self.detokenize(
-                                [token],
-                                prev_tokens=prompt_tokens
-                                + completion_tokens[:returned_tokens],
-                            ).decode("utf-8", errors="ignore"),
-                            "index": 0,
-                            "logprobs": logprobs_or_none,
-                            "finish_reason": None,
-                        }
-                    ],
-                }
+                yield self._create_chunk(
+                    completion_id=completion_id,
+                    created=created,
+                    model_name=model_name,
+                    text=self.detokenize(
+                        [token],
+                        prev_tokens=prompt_tokens
+                        + completion_tokens[:returned_tokens],
+                    ).decode("utf-8", errors="ignore"),
+                    logprobs_or_none=logprobs_or_none,
+                    index=0,
+                    finish_reason=None,
+                    usage=None,
+                )
         else:
             while len(remaining_tokens) > 0:
                 decode_success = False
@@ -1429,20 +1469,16 @@ def logit_bias_processor(
                 remaining_tokens = remaining_tokens[i:]
                 returned_tokens += i
 
-                yield {
-                    "id": completion_id,
-                    "object": "text_completion",
-                    "created": created,
-                    "model": model_name,
-                    "choices": [
-                        {
-                            "text": ts,
-                            "index": 0,
-                            "logprobs": None,
-                            "finish_reason": None,
-                        }
-                    ],
-                }
+                yield self._create_chunk(
+                    completion_id=completion_id,
+                    created=created,
+                    model_name=model_name,
+                    text=ts,
+                    logprobs_or_none=None,
+                    index=0,
+                    finish_reason=None,
+                    usage=None,
+                )
 
             if len(completion_tokens) >= max_tokens:
                 text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
@@ -1521,54 +1557,51 @@ def logit_bias_processor(
                     if token_end_position == end - 1:
                         break
                     returned_tokens += 1
-                    yield {
-                        "id": completion_id,
-                        "object": "text_completion",
-                        "created": created,
-                        "model": model_name,
-                        "choices": [
-                            {
-                                "text": last_text[
-                                    : len(last_text) - (token_end_position - end)
-                                ].decode("utf-8", errors="ignore"),
-                                "index": 0,
-                                "logprobs": logprobs_or_none,
-                                "finish_reason": None,
-                            }
-                        ],
-                    }
+                    yield self._create_chunk(
+                        completion_id=completion_id,
+                        created=created,
+                        model_name=model_name,
+                        text=last_text[
+                            : len(last_text) - (token_end_position - end)
+                        ].decode("utf-8", errors="ignore"),
+                        logprobs_or_none=logprobs_or_none,
+                        index=0,
+                        finish_reason=None,
+                        usage=None,
+                    )
                     break
                 returned_tokens += 1
-                yield {
-                    "id": completion_id,
-                    "object": "text_completion",
-                    "created": created,
-                    "model": model_name,
-                    "choices": [
-                        {
-                            "text": self.detokenize([token]).decode(
-                                "utf-8", errors="ignore"
-                            ),
-                            "index": 0,
-                            "logprobs": logprobs_or_none,
-                            "finish_reason": None,
-                        }
-                    ],
-                }
-            yield {
-                "id": completion_id,
-                "object": "text_completion",
-                "created": created,
-                "model": model_name,
-                "choices": [
-                    {
-                        "text": "",
-                        "index": 0,
-                        "logprobs": None,
-                        "finish_reason": finish_reason,
-                    }
-                ],
+                yield self._create_chunk(
+                    completion_id=completion_id,
+                    created=created,
+                    model_name=model_name,
+                    text=self.detokenize([token]).decode(
+                        "utf-8", errors="ignore"
+                    ),
+                    logprobs_or_none=logprobs_or_none,
+                    index=0,
+                    finish_reason=None,
+                    usage=None,
+                )
+
+            # Final streaming chunk with both finish_reason and usage
+            usage = {
+                "prompt_tokens": len(prompt_tokens),
+                "completion_tokens": returned_tokens,
+                "total_tokens": len(prompt_tokens) + returned_tokens,
             }
+
+            yield self._create_chunk(
+                completion_id=completion_id,
+                created=created,
+                model_name=model_name,
+                text="",
+                logprobs_or_none=None,
+                index=0,
+                finish_reason=finish_reason,
+                usage=usage,
+            )
+
         if self.cache:
             if self.verbose:
                 print("Llama._create_completion: cache save", file=sys.stderr)
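With this change, the streaming path of _create_completion ends with one chunk that carries both finish_reason and a usage dict (prompt, completion, and total token counts). A minimal consumer sketch, assuming a local GGUF model at a placeholder path and the existing create_completion streaming API:

from llama_cpp import Llama

llm = Llama(model_path="./model.gguf")  # placeholder model path, adjust for your setup

last_chunk = None
for chunk in llm.create_completion("Q: What is 2+2? A:", max_tokens=16, stream=True):
    # Each streamed chunk carries the next piece of generated text.
    print(chunk["choices"][0]["text"], end="", flush=True)
    last_chunk = chunk

# After this commit, the final chunk also reports token counts.
if last_chunk is not None and last_chunk.get("usage") is not None:
    print("\nusage:", last_chunk["usage"])  # prompt_tokens, completion_tokens, total_tokens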

llama_cpp/llama_chat_format.py

Lines changed: 29 additions & 4 deletions
@@ -347,6 +347,7 @@ def _convert_text_completion_chunks_to_chat(
                     "finish_reason": chunk["choices"][0]["finish_reason"],
                 }
             ],
+            "usage": chunk.get("usage") if "usage" in chunk else None,
         }
 
 
@@ -431,7 +432,7 @@ def _stream_response_to_function_stream(
                        created = chunk["created"]
                        model = chunk["model"]
                        tool_id = "call_" + "_0_" + tool_name + "_" + chunk["id"]
-                        yield {
+                        response = {
                            "id": id_,
                            "object": "chat.completion.chunk",
                            "created": created,
@@ -450,7 +451,11 @@ def _stream_response_to_function_stream(
                                }
                            ],
                        }
-                        yield {
+                        if "usage" in chunk:
+                            response["usage"] = chunk["usage"]
+                        yield response
+
+                        response = {
                            "id": "chat" + chunk["id"],
                            "object": "chat.completion.chunk",
                            "created": chunk["created"],
@@ -484,10 +489,14 @@ def _stream_response_to_function_stream(
                                }
                            ],
                        }
+                        if "usage" in chunk:
+                            response["usage"] = chunk["usage"]
+                        yield response
                        first = False
                        continue
+
                    assert tool_id is not None
-                    yield {
+                    response = {
                        "id": "chat" + chunk["id"],
                        "object": "chat.completion.chunk",
                        "created": chunk["created"],
@@ -519,9 +528,12 @@ def _stream_response_to_function_stream(
                            }
                        ],
                    }
+                    if "usage" in chunk:
+                        response["usage"] = chunk["usage"]
+                    yield response
 
                if id_ is not None and created is not None and model is not None:
-                    yield {
+                    response = {
                        "id": id_,
                        "object": "chat.completion.chunk",
                        "created": created,
@@ -540,6 +552,9 @@ def _stream_response_to_function_stream(
                            }
                        ],
                    }
+                    if "usage" in chunk:
+                        response["usage"] = chunk["usage"]
+                    yield response
 
            return _stream_response_to_function_stream(chunks)
 
@@ -2120,6 +2135,7 @@ def generate_streaming(tools, functions, function_call, prompt):
                        },
                    }
                ],
+                usage=chunk["usage"] if "usage" in chunk else None,
            )
            first = False
        if tools is not None:
@@ -2160,6 +2176,7 @@ def generate_streaming(tools, functions, function_call, prompt):
                        },
                    }
                ],
+                usage=chunk["usage"] if "usage" in chunk else None,
            )
            # Yield tool_call/function_call stop message
            yield llama_types.CreateChatCompletionStreamResponse(
@@ -2182,6 +2199,7 @@ def generate_streaming(tools, functions, function_call, prompt):
                        },
                    }
                ],
+                usage=chunk["usage"] if "usage" in chunk else None,
            )
        # If "auto" or no tool_choice/function_call
        elif isinstance(function_call, str) and function_call == "auto":
@@ -2217,6 +2235,7 @@ def generate_streaming(tools, functions, function_call, prompt):
                        "finish_reason": None,
                    }
                ],
+                usage=chunk["usage"] if "usage" in chunk else None,
            )
        else:
            prompt += f"{function_name}\n<|content|>"
@@ -2262,6 +2281,7 @@ def generate_streaming(tools, functions, function_call, prompt):
                        },
                    }
                ],
+                usage=chunk["usage"] if "usage" in chunk else None,
            )
            # Generate content
            stops = [RECIPIENT_TOKEN, STOP_TOKEN]
@@ -2299,6 +2319,7 @@ def generate_streaming(tools, functions, function_call, prompt):
                            },
                        }
                    ],
+                    usage=chunk["usage"] if "usage" in chunk else None,
                )
                is_end = False
            elif chunk["choices"][0]["text"] == "\n":
@@ -2328,6 +2349,7 @@ def generate_streaming(tools, functions, function_call, prompt):
                            },
                        }
                    ],
+                    usage=chunk["usage"] if "usage" in chunk else None,
                )
            # Check whether the model wants to generate another turn
            if (
@@ -2360,6 +2382,7 @@ def generate_streaming(tools, functions, function_call, prompt):
                        "finish_reason": "stop",
                    }
                ],
+                usage=chunk["usage"] if "usage" in chunk else None,
            )
            break
        else:
@@ -2409,6 +2432,7 @@ def generate_streaming(tools, functions, function_call, prompt):
                        },
                    }
                ],
+                usage=chunk["usage"] if "usage" in chunk else None,
            )
            prompt += completion_text.strip()
            grammar = None
@@ -2448,6 +2472,7 @@ def generate_streaming(tools, functions, function_call, prompt):
                        },
                    }
                ],
+                usage=chunk["usage"] if "usage" in chunk else None,
            )
            break
 
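The chat-format converters now forward usage from the underlying completion chunks into the chat chunks they yield. A hedged sketch of reading it from create_chat_completion(..., stream=True); the model path is a placeholder, and note that intermediate chunks may carry usage as None, so only a non-None value is kept:

from llama_cpp import Llama

llm = Llama(model_path="./model.gguf")  # placeholder model path

usage = None
for chunk in llm.create_chat_completion(
    messages=[{"role": "user", "content": "Say hi"}],
    stream=True,
):
    delta = chunk["choices"][0]["delta"]
    if delta.get("content"):
        print(delta["content"], end="", flush=True)
    if chunk.get("usage") is not None:  # typically only the final chunk has counts
        usage = chunk["usage"]

print("\nusage:", usage)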

llama_cpp/llama_types.py

Lines changed: 1 addition & 1 deletion
@@ -154,13 +154,13 @@ class ChatCompletionStreamResponseChoice(TypedDict):
     finish_reason: Optional[Literal["stop", "length", "tool_calls", "function_call"]]
     logprobs: NotRequired[Optional[ChatCompletionLogprobs]]
 
-
 class CreateChatCompletionStreamResponse(TypedDict):
     id: str
     model: str
     object: Literal["chat.completion.chunk"]
     created: int
     choices: List[ChatCompletionStreamResponseChoice]
+    usage: NotRequired[CompletionUsage]
 
 
 class ChatCompletionFunctions(TypedDict):
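Because usage is declared NotRequired on CreateChatCompletionStreamResponse, typed consumers should treat it as optional. A small illustrative helper; the function name is ours, not part of the library:

from typing import Optional

from llama_cpp.llama_types import (
    CompletionUsage,
    CreateChatCompletionStreamResponse,
)


def chunk_usage(chunk: CreateChatCompletionStreamResponse) -> Optional[CompletionUsage]:
    # NotRequired key: present only on chunks that carry token counts,
    # which after this commit is typically the final streamed chunk.
    return chunk.get("usage")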
