Commit f896d2c

ngxson and ggerganov authored
server: improve speed of speculative decoding (#17808)
* server: improve speed of speculative decoding
* fix small draft case
* add link to the PR
* server : fix generation time measurement
* server : fix draft acceptance logs (add SRV_CNT, SLT_CNT macros)
* server : add comment
* add PR to docs

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
1 parent e4e9c43 commit f896d2c
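In brief (editor's summary of the diff below, not text from the commit): instead of building a separate per-slot speculative batch and decoding it on its own, each slot now queues its last sampled token plus the draft tokens directly into the shared main batch and remembers their batch indices in i_batch_dft; after the single main decode, the draft is verified from those logits. A rough outline in the diff's own terms (the params_spec / cached_text_tokens setup is omitted here, see the server-context.cpp hunks):

    // 1) while assembling the main batch, per generating slot:
    int n_draft_max = slot.get_n_draft_max();            // 0 disables speculation for this step
    if (n_draft_max > 0) {
        llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, slot.sampled);

        slot.i_batch_dft.push_back(batch.n_tokens);      // remember where each token lands
        common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true);
        slot.prompt.tokens.push_back(slot.sampled);

        for (size_t i = 0; i < draft.size(); i++) {
            slot.i_batch_dft.push_back(batch.n_tokens);
            common_batch_add(batch, draft[i], slot.prompt.tokens.pos_next(), { slot.id }, true);
            slot.prompt.tokens.push_back(draft[i]);
        }
        slot.drafted = std::move(draft);
    }

    // 2) a single llama_decode(ctx, batch) covers all slots, drafts included

    // 3) per slot with a pending draft: verify against the main model's logits
    const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, slot.i_batch_dft, slot.drafted);
    // keep the accepted prefix, set slot.sampled = ids.back(), evict the rejected tail from the KV cache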

File tree

- tools/server/README-dev.md
- tools/server/server-common.h
- tools/server/server-context.cpp

3 files changed: +108, -76 lines changed

tools/server/README-dev.md

Lines changed: 1 addition & 0 deletions
@@ -81,6 +81,7 @@ For detailed instructions, see the [test documentation](./tests/README.md).
 - Separation of HTTP logic into dedicated files: https://github.com/ggml-org/llama.cpp/pull/17216
 - Large-scale code base split into smaller files: https://github.com/ggml-org/llama.cpp/pull/17362
 - Introduction of router mode: https://github.com/ggml-org/llama.cpp/pull/17470
+- Speculative decoding: https://github.com/ggml-org/llama.cpp/pull/17808 and rework in https://github.com/ggml-org/llama.cpp/pull/17808

tools/server/server-common.h

Lines changed: 2 additions & 0 deletions
@@ -18,11 +18,13 @@ const static std::string build_info("b" + std::to_string(LLAMA_BUILD_NUMBER) + "
 using json = nlohmann::ordered_json;
 
 #define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
+#define SLT_CNT(slot, fmt, ...) LOG_CNT("" fmt, __VA_ARGS__)
 #define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
 #define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
 #define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
 
 #define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+#define SRV_CNT(fmt, ...) LOG_CNT("" fmt, __VA_ARGS__)
 #define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
 #define SRV_ERR(fmt, ...) LOG_ERR("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
 #define SRV_DBG(fmt, ...) LOG_DBG("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
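The new SLT_CNT / SRV_CNT variants forward only the raw format string to LOG_CNT, i.e. they print a continuation line without the usual "slot ...: id .. | task .. |" / "srv ...:" prefix. A minimal, hypothetical usage sketch (not taken from the commit; the message and values are illustrative):

    // the first line carries the slot prefix, the follow-up line is printed as-is
    SLT_INF(slot, "speculative decoding stats:%s", "\n");
    SLT_CNT(slot, "    draft acceptance rate = %0.5f (%5d accepted / %5d generated)\n",
            draft_ratio, n_draft_accepted, n_draft_total);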

tools/server/server-context.cpp

Lines changed: 105 additions & 76 deletions
@@ -102,6 +102,11 @@ struct server_slot {
     std::string generated_text;
     llama_tokens generated_tokens;
 
+    // idx of draft tokens in the main batch
+    // non-empty if we went to evaluate draft tokens
+    // ref: https://github.com/ggml-org/llama.cpp/pull/17808
+    std::vector<int32_t> i_batch_dft;
+
     std::vector<completion_token_output> generated_token_probs;
 
     bool has_next_token = true;
@@ -150,7 +155,8 @@ struct server_slot {
 
     struct common_sampler * smpl = nullptr;
 
-    llama_token sampled;
+    llama_token sampled; // in speculative mode, this is the last accepted token
+    llama_tokens drafted;
 
     // stats
     size_t n_sent_text = 0; // number of sent text character
@@ -180,6 +186,8 @@ struct server_slot {
         stopping_word = "";
         n_sent_text = 0;
 
+        drafted.clear();
+        i_batch_dft.clear();
         generated_tokens.clear();
         generated_token_probs.clear();
         json_schema = json();
@@ -255,6 +263,31 @@ struct server_slot {
         generated_token_probs.push_back(token);
     }
 
+    int get_n_draft_max() const {
+        if (!can_speculate()) {
+            return 0;
+        }
+
+        // determine the max draft that fits the current slot state
+        int n_draft_max = task->params.speculative.n_max;
+
+        // note: slot.prompt is not yet expanded with the `id` token sampled above
+        // also, need to leave space for 1 extra token to allow context shifts
+        n_draft_max = std::min(n_draft_max, n_ctx - prompt.n_tokens() - 2);
+
+        if (n_remaining > 0) {
+            n_draft_max = std::min(n_draft_max, n_remaining - 1);
+        }
+
+        SLT_DBG(*this, "max possible draft: %d\n", n_draft_max);
+
+        if (n_draft_max < task->params.speculative.n_min) {
+            SLT_DBG(*this, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, task->params.speculative.n_min);
+            n_draft_max = 0;
+        }
+        return n_draft_max;
+    }
+
     // note: a slot can also be either a parent or a child
     bool is_parent() const {
         return is_processing() && task->n_children > 0;
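A worked example of the clamping in get_n_draft_max() (the numbers are illustrative, not from the commit):

    // n_ctx = 4096, prompt.n_tokens() = 4090, speculative.n_max = 16, speculative.n_min = 2, n_remaining = 8
    int n_draft_max = 16;
    n_draft_max = std::min(n_draft_max, 4096 - 4090 - 2); // -> 4, leave room for 1 extra token (context shift)
    n_draft_max = std::min(n_draft_max, 8 - 1);           // -> 4, do not draft past the remaining budget
    // 4 >= n_min (2), so up to 4 tokens are drafted; with 4093 cached tokens the first bound would
    // drop to 1 < n_min and the function would return 0 (no speculation for this step)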
@@ -353,8 +386,7 @@ struct server_slot {
 
         if (n_draft_total > 0) {
             const float draft_ratio = (float) n_draft_accepted / n_draft_total;
-            SLT_INF(*this,
-                "\n"
+            SLT_CNT(*this,
                 "draft acceptance rate = %0.5f (%5d accepted / %5d generated)\n",
                 draft_ratio, n_draft_accepted, n_draft_total
             );
@@ -1774,14 +1806,57 @@ struct server_context_impl {
                 continue;
             }
 
-            slot.i_batch = batch.n_tokens;
+            // generate draft tokens in speculative decoding mode
+            // TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK]
+            // perform the speculative drafting for all sequences at the same time in a single batch
+            int n_draft_max = slot.get_n_draft_max();
+            if (n_draft_max > 0) {
+                if (mctx) {
+                    // we should never reach this, as speculative is automatically disabled if mmproj is loaded
+                    GGML_ABORT("not supported by multimodal");
+                }
 
-            common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true);
+                struct common_speculative_params params_spec;
+                params_spec.n_draft = n_draft_max;
+                params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.task->params.speculative.n_max;
+                params_spec.p_min = slot.task->params.speculative.p_min;
+                const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens();
+                llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, slot.sampled);
+
+                // add the sampled token to the batch
+                slot.i_batch_dft.push_back(batch.n_tokens);
+                common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true);
+                slot.prompt.tokens.push_back(slot.sampled);
+
+                if (slot.task->params.speculative.n_min > (int) draft.size()) {
+                    SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.task->params.speculative.n_min);
+                    // fallback to normal decoding
+                    slot.i_batch = slot.i_batch_dft[0];
+                    slot.drafted.clear();
+                    slot.i_batch_dft.clear();
+                } else {
+                    // keep track of total number of drafted tokens tested
+                    slot.n_draft_total += draft.size();
+
+                    // add all drafted tokens to the batch
+                    for (size_t i = 0; i < draft.size(); i++) {
+                        slot.i_batch_dft.push_back(batch.n_tokens);
+                        common_batch_add(batch, draft[i], slot.prompt.tokens.pos_next(), { slot.id }, true);
+                        slot.prompt.tokens.push_back(draft[i]);
+                    }
+                    slot.drafted = std::move(draft);
+                }
+            } else {
+                // no speculative decoding
+                slot.i_batch = batch.n_tokens;
 
-            slot.prompt.tokens.push_back(slot.sampled);
+                common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true);
 
-            SLT_DBG(slot, "slot decode token, n_ctx = %d, n_tokens = %d, truncated = %d\n",
-                slot.n_ctx, slot.prompt.n_tokens(), slot.truncated);
+                slot.prompt.tokens.push_back(slot.sampled);
+
+                SLT_DBG(slot, "slot decode token, n_ctx = %d, n_tokens = %d, truncated = %d\n",
+                    slot.n_ctx, slot.prompt.n_tokens(), slot.truncated);
+            }
         }
 
         // process in chunks of params.n_batch
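After this hunk, a slot that drafted successfully contributes n_draft + 1 consecutive entries to the shared batch, and i_batch_dft records the batch index of each one so the acceptance step can look up the right logits later. A sketch of the resulting layout (the indices are illustrative):

    // shared batch (other slots' tokens may precede or follow these entries):
    //   batch index : 7         8         9         10
    //   token       : sampled   draft[0]  draft[1]  draft[2]     (logits requested for all four)
    // slot.i_batch_dft = { 7, 8, 9, 10 };
    // slot.drafted     = { draft[0], draft[1], draft[2] };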
@@ -2345,6 +2420,10 @@ struct server_context_impl {
         // on successful decode, restore the original batch size
         n_batch = llama_n_batch(ctx);
 
+        // technically, measuring the time here excludes the sampling time for the last batch
+        // but on the other hand, we don't want to do too many system calls to measure the time, so it's ok
+        const int64_t t_current = ggml_time_us();
+
         for (auto & slot : slots) {
             // may need to copy state to other slots
             if (slot.state == SLOT_STATE_DONE_PROMPT && slot.is_parent()) {
@@ -2399,6 +2478,10 @@ struct server_context_impl {
                 continue; // continue loop of slots
             }
 
+            if (slot.i_batch_dft.size() > 0) {
+                continue; // sample using speculative decoding
+            }
+
             const int tok_idx = slot.i_batch - i;
 
             llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx);
@@ -2409,8 +2492,6 @@ struct server_context_impl {
 
             slot.n_decoded += 1;
 
-            const int64_t t_current = ggml_time_us();
-
             if (slot.n_decoded == 1) {
                 slot.t_start_generation = t_current;
                 slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
@@ -2439,84 +2520,32 @@ struct server_context_impl {
                }
            }
 
-           // do speculative decoding
-           // TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK]
-           // perform the speculative drafting for all sequences at the same time in a single batch
+           // speculative decoding - main model sample and accept
            for (auto & slot : slots) {
-               if (!slot.is_processing() || !slot.can_speculate()) {
+               if (slot.state != SLOT_STATE_GENERATING || slot.i_batch_dft.empty()) {
                    continue;
                }
 
-               if (slot.state != SLOT_STATE_GENERATING) {
-                   continue;
-               }
-
-               if (mctx) {
-                   // we should never reach this, as speculative is automatically disabled if mmproj is loaded
-                   GGML_ABORT("not supported by multimodal");
-               }
-
-               // determine the max draft that fits the current slot state
-               int n_draft_max = slot.task->params.speculative.n_max;
-
-               // note: slot.prompt is not yet expanded with the `id` token sampled above
-               // also, need to leave space for 1 extra token to allow context shifts
-               n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.prompt.n_tokens() - 2);
-
-               if (slot.n_remaining > 0) {
-                   n_draft_max = std::min(n_draft_max, slot.n_remaining - 1);
-               }
-
-               SLT_DBG(slot, "max possible draft: %d\n", n_draft_max);
-
-               if (n_draft_max < slot.task->params.speculative.n_min) {
-                   SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.task->params.speculative.n_min);
-
-                   continue;
-               }
-
-               llama_token id = slot.sampled;
-
-               struct common_speculative_params params_spec;
-               params_spec.n_draft = n_draft_max;
-               params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.task->params.speculative.n_max;
-               params_spec.p_min = slot.task->params.speculative.p_min;
-
-               const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens();
-               llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id);
-
-               // ignore small drafts
-               if (slot.task->params.speculative.n_min > (int) draft.size()) {
-                   SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.task->params.speculative.n_min);
-
-                   continue;
-               }
-
-               // keep track of total number of drafted tokens tested
-               slot.n_draft_total += draft.size();
-
-               // construct the speculation batch
-               common_batch_clear(slot.batch_spec);
-               common_batch_add (slot.batch_spec, id, slot.prompt.tokens.pos_next(), { slot.id }, true);
-
-               for (size_t i = 0; i < draft.size(); ++i) {
-                   common_batch_add(slot.batch_spec, draft[i], slot.prompt.tokens.pos_next() + 1 + i, { slot.id }, true);
-               }
-
-               SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
-
-               llama_decode(ctx, slot.batch_spec);
+               size_t n_draft = slot.drafted.size();
 
                // the accepted tokens from the speculation
-               const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
+               const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, slot.i_batch_dft, slot.drafted);
+               slot.i_batch_dft.clear();
+               slot.drafted.clear();
 
                slot.n_decoded += ids.size();
 
+               slot.t_token_generation = std::max<int64_t>(1, t_current - slot.t_start_generation) / 1e3;
+
                // update how many tokens out of those tested were accepted
                slot.n_draft_accepted += ids.size() - 1;
 
-               slot.prompt.tokens.push_back(id);
+               // rollback to the state before sampling the draft tokens
+               slot.prompt.tokens.keep_first(slot.prompt.n_tokens() - n_draft);
+
+               // add accepted tokens to the prompt
                slot.prompt.tokens.insert({ids.begin(), ids.end() - 1});
+               slot.sampled = ids.back(); // last accepted token
 
                llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.prompt.n_tokens(), -1);
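A walk-through of the bookkeeping above with made-up numbers: the prompt held 100 tokens before this step, the slot queued its sampled token plus 3 draft tokens (so prompt.n_tokens() == 104 once the batch was built), and the sampler accepted ids = { d0, d1, t }, where d0 and d1 match the draft and t is the first token that diverged from it.

    // illustrative only - values are not from the commit
    // prompt before acceptance: [ ...100 tokens..., sampled, d0, d1, d2 ]   // 104 tokens
    size_t n_draft = 3;                             // slot.drafted.size()
    // keep_first(104 - 3)  -> [ ...100 tokens..., sampled ]                 // drop all drafted tokens
    // insert(d0, d1)       -> [ ...100 tokens..., sampled, d0, d1 ]         // re-add the accepted ones
    // slot.sampled = t;    // t is queued into the next main batch as usual
    // llama_memory_seq_rm(..., slot.prompt.n_tokens() /* = 103 */, -1);     // evict the rejected d2 from the KV cache
    // slot.n_draft_accepted += ids.size() - 1;     // += 2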

@@ -2539,7 +2568,7 @@ struct server_context_impl {
                   }
               }
 
-               SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.prompt.n_tokens());
+               SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int) ids.size() - 1, (int) slot.drafted.size(), slot.prompt.n_tokens());
           }
       }
 