Commit e4e9c43

pwilkin and ggerganov authored
Make graph_max_nodes vary by ubatch size (#17794)
* Make graph_max_nodes vary by ubatch size for models where chunking might explode the graph

* Update src/llama-context.h

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Add missing const

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
1 parent 636fc17 commit e4e9c43

2 files changed: +7 −7

src/llama-context.cpp (6 additions, 6 deletions)

@@ -248,7 +248,10 @@ llama_context::llama_context(
 
     LLAMA_LOG_DEBUG("%s: backend_ptrs.size() = %zu\n", __func__, backend_ptrs.size());
 
-    const size_t max_nodes = this->graph_max_nodes();
+    const uint32_t n_seqs = cparams.n_seq_max;
+    const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
+
+    const size_t max_nodes = this->graph_max_nodes(n_tokens);
 
     LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes);
 
@@ -300,9 +303,6 @@ llama_context::llama_context(
 
     cross.v_embd.clear();
 
-    const uint32_t n_seqs = cparams.n_seq_max;
-    const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
-
     // avoid reserving graphs with zero outputs - assume one output per sequence
     n_outputs = n_seqs;
 
@@ -1386,9 +1386,9 @@ void llama_context::output_reorder() {
 // graph
 //
 
-uint32_t llama_context::graph_max_nodes() const {
+uint32_t llama_context::graph_max_nodes(uint32_t n_tokens) const {
     if (model.arch == LLM_ARCH_QWEN3NEXT) {
-        return std::max<uint32_t>(8192u, 32u*model.n_tensors());
+        return std::max<uint32_t>(n_tokens * 40, 32u * model.n_tensors());
     }
     return std::max<uint32_t>(1024u, 8u*model.n_tensors());
 }
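To get a feel for how the new budget scales, here is a minimal standalone sketch of the QWEN3NEXT branch. The tensor count and ubatch sizes below are illustrative assumptions, not values from a real model, and graph_max_nodes_qwen3next is a hypothetical free-function restatement of the member function above.

    // Standalone sketch of the new node-budget logic for LLM_ARCH_QWEN3NEXT.
    // "n_tensors" stands in for model.n_tensors(); all values are illustrative.
    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    static uint32_t graph_max_nodes_qwen3next(uint32_t n_tokens, uint32_t n_tensors) {
        return std::max<uint32_t>(n_tokens * 40, 32u * n_tensors);
    }

    int main() {
        const uint32_t n_tensors = 1000;
        // Previously the budget was the fixed max(8192, 32 * n_tensors) = 32000,
        // regardless of how many tokens a single ubatch could put into one graph.
        for (uint32_t n_ubatch : {512u, 2048u, 8192u}) {
            std::printf("n_ubatch = %4u -> max_nodes = %6u\n",
                        n_ubatch, graph_max_nodes_qwen3next(n_ubatch, n_tensors));
        }
        // n_ubatch =  512 -> max_nodes =  32000
        // n_ubatch = 2048 -> max_nodes =  81920
        // n_ubatch = 8192 -> max_nodes = 327680
        return 0;
    }

Note that the constructor passes n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch), so the budget grows with the ubatch size but never beyond what the context length allows.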

src/llama-context.h (1 addition, 1 deletion)

@@ -197,7 +197,7 @@ struct llama_context {
     //
 
 public:
-    uint32_t graph_max_nodes() const;
+    uint32_t graph_max_nodes(uint32_t n_tokens) const;
 
     // can reuse the llm_graph_result instance of the context (for example to update a memory module)
     llm_graph_result * get_gf_res_reserve() const;
