2 changes: 2 additions & 0 deletions tools/server/server-common.h
@@ -18,11 +18,13 @@ const static std::string build_info("b" + std::to_string(LLAMA_BUILD_NUMBER) + "
using json = nlohmann::ordered_json;

#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
#define SLT_CNT(slot, fmt, ...) LOG_CNT("" fmt, __VA_ARGS__)
#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)

#define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define SRV_CNT(fmt, ...) LOG_CNT("" fmt, __VA_ARGS__)
#define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define SRV_ERR(fmt, ...) LOG_ERR("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define SRV_DBG(fmt, ...) LOG_DBG("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
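Editor's note on the new *_CNT macros above: SLT_CNT/SRV_CNT forward to LOG_CNT with no "slot ..." / "srv ..." prefix, so they can continue a message started by SLT_INF/SRV_INF on the previous line (the draft-acceptance log further down is switched to this). A minimal usage sketch, not part of the diff; draft_ratio is an assumed local variable:

    // first line carries the usual slot prefix
    SLT_INF(slot, "draft stats: accepted = %d, generated = %d\n", slot.n_draft_accepted, slot.n_draft_total);
    // continuation of the same logical message, no repeated prefix
    SLT_CNT(slot, "draft acceptance rate = %0.5f\n", draft_ratio);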
181 changes: 105 additions & 76 deletions tools/server/server-context.cpp
@@ -101,6 +101,11 @@ struct server_slot {
std::string generated_text;
llama_tokens generated_tokens;

// indices of the draft tokens in the main batch
// non-empty if draft tokens have been added to the current batch for evaluation
// ref: https://github.com/ggml-org/llama.cpp/pull/17808
std::vector<int32_t> i_batch_dft;

std::vector<completion_token_output> generated_token_probs;

bool has_next_token = true;
@@ -149,7 +154,8 @@ struct server_slot {

struct common_sampler * smpl = nullptr;

llama_token sampled;
llama_token sampled; // in speculative mode, this is the last accepted token
llama_tokens drafted;

// stats
size_t n_sent_text = 0; // number of sent text characters
@@ -179,6 +185,8 @@ struct server_slot {
stopping_word = "";
n_sent_text = 0;

drafted.clear();
i_batch_dft.clear();
generated_tokens.clear();
generated_token_probs.clear();
json_schema = json();
@@ -254,6 +262,31 @@ struct server_slot {
generated_token_probs.push_back(token);
}

int get_n_draft_max() const {
if (!can_speculate()) {
return 0;
}

// determine the max draft that fits the current slot state
int n_draft_max = task->params.speculative.n_max;

// note: slot.prompt has not yet been expanded with the last sampled token (slot.sampled)
// also, need to leave space for 1 extra token to allow context shifts
n_draft_max = std::min(n_draft_max, n_ctx - prompt.n_tokens() - 2);

if (n_remaining > 0) {
n_draft_max = std::min(n_draft_max, n_remaining - 1);
}

SLT_DBG(*this, "max possible draft: %d\n", n_draft_max);

if (n_draft_max < task->params.speculative.n_min) {
SLT_DBG(*this, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, task->params.speculative.n_min);
n_draft_max = 0;
}
return n_draft_max;
}
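For illustration, a worked instance of the clamping in get_n_draft_max() (all numbers assumed by the editor, not taken from the PR): with n_ctx = 4096, prompt.n_tokens() = 4000, speculative.n_max = 16, speculative.n_min = 4 and n_remaining = 8:

    int n_draft_max = 16;                                  // speculative.n_max
    n_draft_max = std::min(n_draft_max, 4096 - 4000 - 2);  // context clamp: min(16, 94) = 16
    n_draft_max = std::min(n_draft_max, 8 - 1);            // n_remaining clamp: min(16, 7) = 7
    // 7 >= n_min (4), so get_n_draft_max() returns 7 and up to 7 tokens are drafted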

void release() {
if (is_processing()) {
GGML_ASSERT(task);
@@ -343,8 +376,7 @@ struct server_slot {

if (n_draft_total > 0) {
const float draft_ratio = (float) n_draft_accepted / n_draft_total;
SLT_INF(*this,
"\n"
SLT_CNT(*this,
"draft acceptance rate = %0.5f (%5d accepted / %5d generated)\n",
draft_ratio, n_draft_accepted, n_draft_total
);
@@ -1745,14 +1777,59 @@ struct server_context_impl {
continue;
}

slot.i_batch = batch.n_tokens;
// generate draft tokens in speculative decoding mode
// TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK]
// perform the speculative drafting for all sequences at the same time in a single batch
int n_draft_max = slot.get_n_draft_max();
if (n_draft_max > 0) {
if (mctx) {
// we should never reach this, as speculative is automatically disabled if mmproj is loaded
GGML_ABORT("not supported by multimodal");
}

struct common_speculative_params params_spec;
params_spec.n_draft = n_draft_max;
params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.task->params.speculative.n_max;
params_spec.p_min = slot.task->params.speculative.p_min;
const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens();
llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, slot.sampled);

// add the sampled token to the batch
slot.i_batch_dft.push_back(batch.n_tokens);
common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true);
slot.prompt.tokens.push_back(slot.sampled);

if (slot.task->params.speculative.n_min > (int) draft.size()) {
SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.task->params.speculative.n_min);
// fallback to normal decoding
slot.i_batch = slot.i_batch_dft[0];
slot.drafted.clear();
slot.i_batch_dft.clear();

} else {

// keep track of total number of drafted tokens tested
slot.n_draft_total += draft.size();

common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true);
// add all drafted tokens to the batch
for (size_t i = 0; i < draft.size(); i++) {
slot.i_batch_dft.push_back(batch.n_tokens);
common_batch_add(batch, draft[i], slot.prompt.tokens.pos_next(), { slot.id }, true);
slot.prompt.tokens.push_back(draft[i]);
}
slot.drafted = std::move(draft);
}
} else {
// no speculative decoding
slot.i_batch = batch.n_tokens;

slot.prompt.tokens.push_back(slot.sampled);
common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true);

SLT_DBG(slot, "slot decode token, n_ctx = %d, n_tokens = %d, truncated = %d\n",
slot.n_ctx, slot.prompt.n_tokens(), slot.truncated);
slot.prompt.tokens.push_back(slot.sampled);

SLT_DBG(slot, "slot decode token, n_ctx = %d, n_tokens = %d, truncated = %d\n",
slot.n_ctx, slot.prompt.n_tokens(), slot.truncated);
}
}

// process in chunks of params.n_batch
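Editor's sketch of the resulting main-batch layout for one generating slot (row indices and token names assumed for illustration, not code from the PR): the last accepted token is appended first, the draft follows at consecutive positions, and slot.i_batch_dft records the batch rows that are verified later.

    // suppose batch.n_tokens was 7 before this slot was added and the draft is {t1, t2, t3}
    // batch row:          7        8    9    10
    // token:              sampled  t1   t2   t3
    // slot.i_batch_dft = {7,       8,   9,   10}
    // all four rows request logits, so the draft is verified against the main model
    // in the same llama_decode() call that handles the other slots' tokens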
@@ -2307,6 +2384,8 @@ struct server_context_impl {
// on successful decode, restore the original batch size
n_batch = llama_n_batch(ctx);

const int64_t t_current = ggml_time_us();

for (auto & slot : slots) {
// optionally send prompt processing progress
if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_DONE_PROMPT) {
@@ -2341,6 +2420,10 @@ struct server_context_impl {
continue; // continue loop of slots
}

if (slot.i_batch_dft.size() > 0) {
continue; // sample using speculative decoding
}

const int tok_idx = slot.i_batch - i;

llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx);
@@ -2351,8 +2434,6 @@ struct server_context_impl {

slot.n_decoded += 1;

const int64_t t_current = ggml_time_us();

if (slot.n_decoded == 1) {
slot.t_start_generation = t_current;
slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
@@ -2381,84 +2462,32 @@ struct server_context_impl {
}
}

// do speculative decoding
// TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK]
// perform the speculative drafting for all sequences at the same time in a single batch
// speculative decoding: sample with the main model and accept/reject the drafted tokens
for (auto & slot : slots) {
if (!slot.is_processing() || !slot.can_speculate()) {
continue;
}

if (slot.state != SLOT_STATE_GENERATING) {
continue;
}

if (mctx) {
// we should never reach this, as speculative is automatically disabled if mmproj is loaded
GGML_ABORT("not supported by multimodal");
}

// determine the max draft that fits the current slot state
int n_draft_max = slot.task->params.speculative.n_max;

// note: slot.prompt is not yet expanded with the `id` token sampled above
// also, need to leave space for 1 extra token to allow context shifts
n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.prompt.n_tokens() - 2);

if (slot.n_remaining > 0) {
n_draft_max = std::min(n_draft_max, slot.n_remaining - 1);
}

SLT_DBG(slot, "max possible draft: %d\n", n_draft_max);

if (n_draft_max < slot.task->params.speculative.n_min) {
SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.task->params.speculative.n_min);

continue;
}

llama_token id = slot.sampled;

struct common_speculative_params params_spec;
params_spec.n_draft = n_draft_max;
params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.task->params.speculative.n_max;
params_spec.p_min = slot.task->params.speculative.p_min;

const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens();
llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id);

// ignore small drafts
if (slot.task->params.speculative.n_min > (int) draft.size()) {
SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.task->params.speculative.n_min);

if (slot.state != SLOT_STATE_GENERATING || slot.i_batch_dft.empty()) {
continue;
}

// keep track of total number of drafted tokens tested
slot.n_draft_total += draft.size();

// construct the speculation batch
common_batch_clear(slot.batch_spec);
common_batch_add (slot.batch_spec, id, slot.prompt.tokens.pos_next(), { slot.id }, true);

for (size_t i = 0; i < draft.size(); ++i) {
common_batch_add(slot.batch_spec, draft[i], slot.prompt.tokens.pos_next() + 1 + i, { slot.id }, true);
}

SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);

llama_decode(ctx, slot.batch_spec);
size_t n_draft = slot.drafted.size();

// the accepted tokens from the speculation
const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, slot.i_batch_dft, slot.drafted);
slot.i_batch_dft.clear();
slot.drafted.clear();

slot.n_decoded += ids.size();

slot.t_token_generation = std::max<int64_t>(1, t_current - slot.t_start_generation) / 1e3;

// update how many tokens out of those tested were accepted
slot.n_draft_accepted += ids.size() - 1;

slot.prompt.tokens.push_back(id);
// rollback to the state before sampling the draft tokens
slot.prompt.tokens.keep_first(slot.prompt.n_tokens() - n_draft);

// add accepted tokens to the prompt
slot.prompt.tokens.insert({ids.begin(), ids.end() - 1});
slot.sampled = ids.back(); // last accepted token

llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.prompt.n_tokens(), -1);

@@ -2481,7 +2510,7 @@ struct server_context_impl {
}
}

SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.prompt.n_tokens());
SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int) ids.size() - 1, (int) slot.drafted.size(), slot.prompt.n_tokens());
}
}
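To make the accept/rollback bookkeeping above concrete, an editor's walk-through with assumed values (3 drafted tokens, of which 2 are accepted):

    // drafted = {t1, t2, t3}               -> n_draft = 3
    // ids     = {t1, t2, x}                -> t1, t2 accepted; x is the corrective sample
    // prompt currently ends ..., sampled, t1, t2, t3   (pushed while building the batch)
    // keep_first(n_tokens - 3)             drops t1, t2, t3 from the prompt
    // insert({t1, t2})                     re-adds only the accepted prefix
    // slot.sampled = x                     the next decode step extends from x
    // llama_memory_seq_rm(...)             trims the KV cache past the new prompt length
    // n_draft_accepted += ids.size() - 1   i.e. += 2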

Expand Down