@@ -102,6 +102,11 @@ struct server_slot {
     std::string generated_text;
     llama_tokens generated_tokens;
 
+    // idx of draft tokens in the main batch
+    // non-empty if we want to evaluate draft tokens
+    // ref: https://github.com/ggml-org/llama.cpp/pull/17808
+    std::vector<int32_t> i_batch_dft;
+
     std::vector<completion_token_output> generated_token_probs;
 
     bool has_next_token = true;
@@ -150,7 +155,8 @@ struct server_slot {
 
     struct common_sampler * smpl = nullptr;
 
-    llama_token sampled;
+    llama_token sampled; // in speculative mode, this is the last accepted token
+    llama_tokens drafted;
 
     // stats
     size_t n_sent_text = 0; // number of sent text character
@@ -180,6 +186,8 @@ struct server_slot {
         stopping_word = "";
         n_sent_text = 0;
 
+        drafted.clear();
+        i_batch_dft.clear();
         generated_tokens.clear();
         generated_token_probs.clear();
         json_schema = json();
@@ -255,6 +263,31 @@ struct server_slot {
         generated_token_probs.push_back(token);
     }
 
+    int get_n_draft_max() const {
+        if (!can_speculate()) {
+            return 0;
+        }
+
+        // determine the max draft that fits the current slot state
+        int n_draft_max = task->params.speculative.n_max;
+
+        // note: slot.prompt is not yet expanded with the `id` token sampled above
+        // also, need to leave space for 1 extra token to allow context shifts
+        n_draft_max = std::min(n_draft_max, n_ctx - prompt.n_tokens() - 2);
+
+        if (n_remaining > 0) {
+            n_draft_max = std::min(n_draft_max, n_remaining - 1);
+        }
+
+        SLT_DBG(*this, "max possible draft: %d\n", n_draft_max);
+
+        if (n_draft_max < task->params.speculative.n_min) {
+            SLT_DBG(*this, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, task->params.speculative.n_min);
+            n_draft_max = 0;
+        }
+        return n_draft_max;
+    }
+
     // note: a slot can also be either a parent or a child
     bool is_parent() const {
         return is_processing() && task->n_children > 0;
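Aside: the clamping performed by get_n_draft_max() above can be illustrated with a small standalone sketch. This is not part of the patch; plain ints stand in for the slot fields (n_ctx, prompt.n_tokens(), n_remaining) and for the speculative n_min/n_max parameters, and the values in main() are made up.

#include <algorithm>
#include <cstdio>

// Standalone illustration: clamp the requested draft size so the drafted tokens
// still fit in the slot's context (leaving room for the sampled token that has
// not been appended yet, plus 1 extra token for context shifts) and do not
// exceed the number of tokens the request is still allowed to generate.
static int draft_budget(int n_max, int n_min, int n_ctx, int n_prompt_tokens, int n_remaining) {
    int n_draft_max = n_max;

    n_draft_max = std::min(n_draft_max, n_ctx - n_prompt_tokens - 2);

    if (n_remaining > 0) {
        n_draft_max = std::min(n_draft_max, n_remaining - 1);
    }

    // a draft smaller than n_min is not worth a speculative pass
    if (n_draft_max < n_min) {
        n_draft_max = 0;
    }

    return n_draft_max;
}

int main() {
    // e.g. n_max = 16, n_min = 2, context of 4096 with 4090 cached tokens, 8 tokens left to generate
    printf("%d\n", draft_budget(16, 2, 4096, 4090, 8)); // prints 4 - limited by the context
    return 0;
}

The "- 2" mirrors the comments in the patch: one position for the token sampled this iteration, which is not yet in slot.prompt, and one extra token to allow context shifts.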
@@ -353,8 +386,7 @@ struct server_slot {
 
         if (n_draft_total > 0) {
             const float draft_ratio = (float) n_draft_accepted / n_draft_total;
-            SLT_INF(*this,
-                    "\n"
+            SLT_CNT(*this,
                     "draft acceptance rate = %0.5f (%5d accepted / %5d generated)\n",
                     draft_ratio, n_draft_accepted, n_draft_total
             );
@@ -1774,14 +1806,57 @@ struct server_context_impl {
                 continue;
             }
 
-            slot.i_batch = batch.n_tokens;
+            // generate draft tokens in speculative decoding mode
+            // TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK]
+            //       perform the speculative drafting for all sequences at the same time in a single batch
+            int n_draft_max = slot.get_n_draft_max();
+            if (n_draft_max > 0) {
+                if (mctx) {
+                    // we should never reach this, as speculative is automatically disabled if mmproj is loaded
+                    GGML_ABORT("not supported by multimodal");
+                }
 
-            common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true);
+                struct common_speculative_params params_spec;
+                params_spec.n_draft = n_draft_max;
+                params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.task->params.speculative.n_max;
+                params_spec.p_min   = slot.task->params.speculative.p_min;
+                const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens();
+                llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, slot.sampled);
+
+                // add the sampled token to the batch
+                slot.i_batch_dft.push_back(batch.n_tokens);
+                common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true);
+                slot.prompt.tokens.push_back(slot.sampled);
+
+                if (slot.task->params.speculative.n_min > (int) draft.size()) {
+                    SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.task->params.speculative.n_min);
+                    // fallback to normal decoding
+                    slot.i_batch = slot.i_batch_dft[0];
+                    slot.drafted.clear();
+                    slot.i_batch_dft.clear();
+                } else {
+                    // keep track of total number of drafted tokens tested
+                    slot.n_draft_total += draft.size();
+
+                    // add all drafted tokens to the batch
+                    for (size_t i = 0; i < draft.size(); i++) {
+                        slot.i_batch_dft.push_back(batch.n_tokens);
+                        common_batch_add(batch, draft[i], slot.prompt.tokens.pos_next(), { slot.id }, true);
+                        slot.prompt.tokens.push_back(draft[i]);
+                    }
+                    slot.drafted = std::move(draft);
+                }
+            } else {
+                // no speculative decoding
+                slot.i_batch = batch.n_tokens;
 
-            slot.prompt.tokens.push_back(slot.sampled);
+                common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true);
 
-            SLT_DBG(slot, "slot decode token, n_ctx = %d, n_tokens = %d, truncated = %d\n",
-                    slot.n_ctx, slot.prompt.n_tokens(), slot.truncated);
+                slot.prompt.tokens.push_back(slot.sampled);
+
+                SLT_DBG(slot, "slot decode token, n_ctx = %d, n_tokens = %d, truncated = %d\n",
+                        slot.n_ctx, slot.prompt.n_tokens(), slot.truncated);
+            }
         }
 
         // process in chunks of params.n_batch
@@ -2345,6 +2420,10 @@ struct server_context_impl {
             // on successful decode, restore the original batch size
             n_batch = llama_n_batch(ctx);
 
+            // technically, measuring the time here excludes the sampling time for the last batch
+            // but on the other hand, we don't want to do too many system calls to measure the time, so it's ok
+            const int64_t t_current = ggml_time_us();
+
             for (auto & slot : slots) {
                 // may need to copy state to other slots
                 if (slot.state == SLOT_STATE_DONE_PROMPT && slot.is_parent()) {
@@ -2399,6 +2478,10 @@ struct server_context_impl {
                     continue; // continue loop of slots
                 }
 
+                if (slot.i_batch_dft.size() > 0) {
+                    continue; // sample using speculative decoding
+                }
+
                 const int tok_idx = slot.i_batch - i;
 
                 llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx);
@@ -2409,8 +2492,6 @@ struct server_context_impl {
 
                 slot.n_decoded += 1;
-                const int64_t t_current = ggml_time_us();
-
                 if (slot.n_decoded == 1) {
                     slot.t_start_generation = t_current;
                     slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
@@ -2439,84 +2520,32 @@ struct server_context_impl {
                     }
                 }
 
-            // do speculative decoding
-            // TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK]
-            //       perform the speculative drafting for all sequences at the same time in a single batch
+            // speculative decoding - main model sample and accept
             for (auto & slot : slots) {
-                if (!slot.is_processing() || !slot.can_speculate()) {
+                if (slot.state != SLOT_STATE_GENERATING || slot.i_batch_dft.empty()) {
                     continue;
                 }
 
-                if (slot.state != SLOT_STATE_GENERATING) {
-                    continue;
-                }
-
-                if (mctx) {
-                    // we should never reach this, as speculative is automatically disabled if mmproj is loaded
-                    GGML_ABORT("not supported by multimodal");
-                }
-
-                // determine the max draft that fits the current slot state
-                int n_draft_max = slot.task->params.speculative.n_max;
-
-                // note: slot.prompt is not yet expanded with the `id` token sampled above
-                // also, need to leave space for 1 extra token to allow context shifts
-                n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.prompt.n_tokens() - 2);
-
-                if (slot.n_remaining > 0) {
-                    n_draft_max = std::min(n_draft_max, slot.n_remaining - 1);
-                }
-
-                SLT_DBG(slot, "max possible draft: %d\n", n_draft_max);
-
-                if (n_draft_max < slot.task->params.speculative.n_min) {
-                    SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.task->params.speculative.n_min);
-
-                    continue;
-                }
-
-                llama_token id = slot.sampled;
-
-                struct common_speculative_params params_spec;
-                params_spec.n_draft = n_draft_max;
-                params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.task->params.speculative.n_max;
-                params_spec.p_min = slot.task->params.speculative.p_min;
-
-                const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens();
-                llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id);
-
-                // ignore small drafts
-                if (slot.task->params.speculative.n_min > (int) draft.size()) {
-                    SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.task->params.speculative.n_min);
-
-                    continue;
-                }
-
-                // keep track of total number of drafted tokens tested
-                slot.n_draft_total += draft.size();
-
-                // construct the speculation batch
-                common_batch_clear(slot.batch_spec);
-                common_batch_add(slot.batch_spec, id, slot.prompt.tokens.pos_next(), { slot.id }, true);
-
-                for (size_t i = 0; i < draft.size(); ++i) {
-                    common_batch_add(slot.batch_spec, draft[i], slot.prompt.tokens.pos_next() + 1 + i, { slot.id }, true);
-                }
-
-                SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
-
-                llama_decode(ctx, slot.batch_spec);
+                size_t n_draft = slot.drafted.size();
 
                 // the accepted tokens from the speculation
-                const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
+                const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, slot.i_batch_dft, slot.drafted);
+                slot.i_batch_dft.clear();
+                slot.drafted.clear();
 
                 slot.n_decoded += ids.size();
 
+                slot.t_token_generation = std::max<int64_t>(1, t_current - slot.t_start_generation) / 1e3;
+
                 // update how many tokens out of those tested were accepted
                 slot.n_draft_accepted += ids.size() - 1;
 
-                slot.prompt.tokens.push_back(id);
+                // rollback to the state before sampling the draft tokens
+                slot.prompt.tokens.keep_first(slot.prompt.n_tokens() - n_draft);
+
+                // add accepted tokens to the prompt
                 slot.prompt.tokens.insert({ids.begin(), ids.end() - 1});
+                slot.sampled = ids.back(); // last accepted token
 
                 llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.prompt.n_tokens(), -1);
 
@@ -2539,7 +2568,7 @@ struct server_context_impl {
                     }
                 }
 
-                SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.prompt.n_tokens());
+                SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int) ids.size() - 1, (int) n_draft, slot.prompt.n_tokens());
             }
         }
 
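For reference, a small standalone sketch of the bookkeeping that follows common_sampler_sample_and_accept_n() in the hunks above: the drafted tokens are rolled back from the cached prompt (keep_first() in the patch, resize() here) and only the accepted tokens except the last one are re-inserted, with the last accepted token carried over as the next sampled. This is not part of the patch; std::vector<int> replaces the server's token containers and the token values are invented.

#include <cassert>
#include <vector>

int main() {
    // cached prompt right before decoding: prefix, the sampled token 100, then 3 drafted tokens
    std::vector<int> prompt  = {1, 2, 3, 100, 41, 42, 43};
    std::vector<int> drafted = {41, 42, 43};

    // suppose the main model accepted the first two drafts and then sampled a bonus token 77
    std::vector<int> ids = {41, 42, 77};

    const size_t n_draft = drafted.size();

    // rollback to the state before the draft tokens were appended (keep_first in the patch)
    prompt.resize(prompt.size() - n_draft);

    // re-insert the accepted tokens, except the last one
    prompt.insert(prompt.end(), ids.begin(), ids.end() - 1);

    // the last accepted token becomes the next `sampled` and is added to the batch next iteration
    int sampled = ids.back();

    assert((prompt == std::vector<int>{1, 2, 3, 100, 41, 42}));
    assert(sampled == 77);

    // draft acceptance for this pass: ids.size() - 1 = 2 accepted out of 3 drafted
    return 0;
}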