
Commit 09519f7

server : make cache_reuse configurable per request
1 parent 79d6189 commit 09519f7

4 files changed: +31 -15 lines changed


tools/server/README.md

Lines changed: 2 additions & 0 deletions
@@ -495,6 +495,8 @@ By default, this value is set to `0`, meaning no tokens are kept. Use `-1` to re

 `n_cmpl`: Number of completions to generate from the current prompt. If input has multiple prompts, the output will have N prompts times `n_cmpl` entries.

+`n_cache_reuse`: Min chunk size to attempt reusing from the cache via KV shifting. For more info, see `--cache-reuse` arg. Default: `0`, which is disabled.
+
 `stream`: Allows receiving each predicted token in real-time instead of waiting for the completion to finish (uses a different response format). To enable this, set to `true`.

 `stop`: Specify a JSON array of stopping strings.
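
As a usage illustration (not part of this commit): a client can now override the server-wide `--cache-reuse` default for a single `/completion` request by including `n_cache_reuse` in the request body. A minimal C++ sketch of building such a body, assuming the nlohmann::json library on the client side; the prompt text and the value 256 are made up:

```cpp
// Hypothetical client-side sketch (not part of this commit): build a
// /completion request body that overrides the server-wide cache-reuse
// default for this one request. Assumes the nlohmann::json library.
#include <nlohmann/json.hpp>

#include <iostream>

int main() {
    nlohmann::json req = {
        {"prompt",        "Summarize the following document: ..."},
        {"n_predict",     128},
        {"n_cache_reuse", 256}   // per-request override added by this commit
    };

    // POST this body to the server's /completion endpoint (e.g. with curl).
    std::cout << req.dump(2) << std::endl;
    return 0;
}
```

Per the server-task.cpp change below, sending `"n_cache_reuse": 0` instead would disable reuse for that request even if the server was started with `--cache-reuse`.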

tools/server/server-context.cpp

Lines changed: 13 additions & 4 deletions
@@ -1880,8 +1880,18 @@ struct server_context_impl {
                 n_past = std::min(n_past, slot.alora_invocation_start - 1);
             }

+            const auto n_cache_reuse = slot.task->params.n_cache_reuse;
+
+            const bool can_cache_reuse =
+                llama_memory_can_shift(llama_get_memory(ctx)) &&
+                !slot.prompt.tokens.has_mtmd;
+
+            if (!can_cache_reuse && n_cache_reuse > 0) {
+                SLT_WRN(slot, "cache reuse is not supported - ignoring n_cache_reuse = %d\n", n_cache_reuse);
+            }
+
             // reuse chunks from the cached prompt by shifting their KV cache in the new position
-            if (params_base.n_cache_reuse > 0) {
+            if (can_cache_reuse && n_cache_reuse > 0) {
                 GGML_ASSERT(!slot.prompt.tokens.has_mtmd);

                 size_t head_c = n_past; // cache
@@ -1892,7 +1902,7 @@
                     GGML_ABORT("not supported by multimodal");
                 }

-                SLT_DBG(slot, "trying to reuse chunks with size > %d, n_past = %d\n", params_base.n_cache_reuse, n_past);
+                SLT_DBG(slot, "trying to reuse chunks with size > %d, n_past = %d\n", n_cache_reuse, n_past);

                 while (head_c < slot.prompt.tokens.size() &&
                        head_p < input_tokens.size()) {
@@ -1901,11 +1911,10 @@
                     while (head_c + n_match < slot.prompt.tokens.size() &&
                            head_p + n_match < input_tokens.size() &&
                            slot.prompt.tokens[head_c + n_match] == input_tokens[head_p + n_match]) {
-
                         n_match++;
                     }

-                    if (n_match >= (size_t) params_base.n_cache_reuse) {
+                    if (n_match >= (size_t) n_cache_reuse) {
                         SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match);
                         //for (size_t i = head_p; i < head_p + n_match; i++) {
                         //    SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
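
For context on the hunks above: the reuse pass greedily finds runs of identical tokens shared by the cached prompt and the new prompt, and only shifts runs of at least `n_cache_reuse` tokens in the KV cache. The stand-alone sketch below mimics that matching loop with plain token-id vectors instead of the server's slot/cache structures; the data, starting offsets, and threshold are invented for illustration:

```cpp
// Stand-alone illustration of the chunk-matching loop above, using plain
// token-id vectors instead of slot.prompt.tokens / input_tokens. The data,
// offsets and threshold are invented; only the control flow mirrors the diff.
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
    std::vector<int> cached = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; // tokens already in the KV cache
    std::vector<int> prompt = {1, 2, 6, 7, 8, 9, 11};          // tokens of the new request
    const std::size_t n_cache_reuse = 3;                       // per-request minimum chunk size

    std::size_t head_c = 2; // start after the common prefix (what the server calls n_past)
    std::size_t head_p = 2;

    while (head_c < cached.size() && head_p < prompt.size()) {
        // length of the identical run starting at (head_c, head_p)
        std::size_t n_match = 0;
        while (head_c + n_match < cached.size() &&
               head_p + n_match < prompt.size() &&
               cached[head_c + n_match] == prompt[head_p + n_match]) {
            n_match++;
        }

        if (n_match >= n_cache_reuse) {
            // the server would shift this KV range instead of recomputing it
            std::printf("reuse %zu tokens: cache [%zu, %zu) -> prompt [%zu, %zu)\n",
                        n_match, head_c, head_c + n_match, head_p, head_p + n_match);
            head_c += n_match;
            head_p += n_match;
        } else {
            head_c += 1; // run too short to be worth shifting; try the next cache position
        }
    }
    return 0;
}
```

On this sample data it reports a single reusable run of 4 tokens; everything outside that run would be recomputed as usual.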

tools/server/server-task.cpp

Lines changed: 7 additions & 5 deletions
@@ -155,11 +155,12 @@ task_params server_task::params_from_json_cmpl(

     // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
     task_params defaults;
-    defaults.sampling    = params_base.sampling;
-    defaults.speculative = params_base.speculative;
-    defaults.n_keep      = params_base.n_keep;
-    defaults.n_predict   = params_base.n_predict;
-    defaults.antiprompt  = params_base.antiprompt;
+    defaults.sampling      = params_base.sampling;
+    defaults.speculative   = params_base.speculative;
+    defaults.n_keep        = params_base.n_keep;
+    defaults.n_predict     = params_base.n_predict;
+    defaults.n_cache_reuse = params_base.n_cache_reuse;
+    defaults.antiprompt    = params_base.antiprompt;

     // enabling this will output extra debug information in the HTTP responses from the server
     params.verbose = params_base.verbosity > 9;
@@ -176,6 +177,7 @@
     params.n_keep            = json_value(data, "n_keep", defaults.n_keep);
     params.n_discard         = json_value(data, "n_discard", defaults.n_discard);
     params.n_cmpl            = json_value(data, "n_cmpl", json_value(data, "n", 1));
+    params.n_cache_reuse     = json_value(data, "n_cache_reuse", defaults.n_cache_reuse);
     //params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement
     params.t_max_predict_ms  = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
     params.response_fields   = json_value(data, "response_fields", std::vector<std::string>());
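
The new `params.n_cache_reuse` line follows the same request-or-default pattern as the surrounding fields: use the value from the request body if present, otherwise fall back to the server-wide default seeded from `params_base`. A minimal sketch of that pattern, assuming the nlohmann::json library; the `json_value` helper below is a simplified stand-in, not the server's actual implementation:

```cpp
// Simplified stand-in for the request-or-default lookup; not the server's
// actual json_value helper. Assumes the nlohmann::json library.
#include <nlohmann/json.hpp>

#include <cstdint>
#include <cstdio>

template <typename T>
static T json_value(const nlohmann::json & body, const char * key, const T & default_value) {
    // take the field from the request body if present and non-null, else the default
    if (body.contains(key) && !body.at(key).is_null()) {
        return body.at(key).get<T>();
    }
    return default_value;
}

int main() {
    const int32_t server_default = 0; // what --cache-reuse (params_base) would provide

    const nlohmann::json with_override    = {{"prompt", "hi"}, {"n_cache_reuse", 256}};
    const nlohmann::json without_override = {{"prompt", "hi"}};

    std::printf("override: %d\n", json_value(with_override,    "n_cache_reuse", server_default)); // 256
    std::printf("fallback: %d\n", json_value(without_override, "n_cache_reuse", server_default)); // 0
    return 0;
}
```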

tools/server/server-task.h

Lines changed: 9 additions & 6 deletions
@@ -55,25 +55,28 @@ struct task_params {
     int32_t n_indent      = 0; // minimum line indentation for the generated text in number of whitespace characters
     int32_t n_cmpl        = 1; // number of completions to generate from this prompt

+    int32_t n_cache_reuse = 0; // min chunk size to attempt reusing from the cache via KV shifting (0 = disabled)
+
     int64_t t_max_prompt_ms  = -1; // TODO: implement
     int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit

     std::vector<common_adapter_lora_info> lora;

     std::vector<std::string> antiprompt;
     std::vector<std::string> response_fields;
-    bool timings_per_token = false;
+
+    bool timings_per_token   = false;
     bool post_sampling_probs = false;

     struct common_params_sampling sampling;
     struct common_params_speculative speculative;

     // response formatting
-    bool verbose = false;
-    task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
-    std::string oaicompat_model;
-    std::string oaicompat_cmpl_id;
-    common_chat_syntax oaicompat_chat_syntax;
+    bool               verbose  = false;
+    task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
+    std::string        oaicompat_model;
+    std::string        oaicompat_cmpl_id;
+    common_chat_syntax oaicompat_chat_syntax;

     // Embeddings
     int32_t embd_normalize = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm)
