
Commit 59977eb

yifant-code, ytian218, and ggerganov authored
server: fix crash when batch > ubatch with embeddings (#17912)
* server: fix crash when batch > ubatch with embeddings (#12836)

Fixes #12836 where the server crashes with GGML_ASSERT failure when running with embeddings enabled and n_batch > n_ubatch.

Root cause: Embeddings use non-causal attention which requires all tokens to be processed within a single ubatch. When n_batch > n_ubatch, the server attempts to split processing, causing assertion failure.

Solution:
- Add parameter validation in main() after common_params_parse()
- When embeddings enabled and n_batch > n_ubatch:
  * Log warnings explaining the issue
  * Automatically set n_batch = n_ubatch
  * Prevent server crash

This follows the approach suggested by @ggerganov in issue #12836.

Note: This supersedes stalled PR #12940 which attempted a runtime fix in the old examples/server/server.cpp location. This implementation validates at startup in tools/server/server.cpp (current location).

Testing:
- Build: Compiles successfully
- Validation triggers: Warns when -b > -ub with --embedding
- Auto-correction works: Adjusts n_batch = n_ubatch
- No false positives: Valid params don't trigger warnings
- Verified on macOS M3 Pro with embedding model

* Update tools/server/server.cpp

---------

Co-authored-by: ytian218 <ytian218@bloomberg.net>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
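
The root cause described above comes down to how a logical batch is cut into physical micro-batches. The following is a minimal sketch of that arithmetic, assuming a batch of n_tokens is processed in chunks of at most n_ubatch tokens. Both `count_ubatches` and `fits_single_ubatch` are hypothetical helpers written for illustration, not llama.cpp functions.

```cpp
// Hypothetical helper (not part of llama.cpp): number of micro-batches needed
// to process n_tokens when each micro-batch holds at most n_ubatch tokens.
// Assumes n_tokens >= 0 and n_ubatch > 0.
inline int count_ubatches(int n_tokens, int n_ubatch) {
    return (n_tokens + n_ubatch - 1) / n_ubatch; // ceiling division
}

// Hypothetical helper: non-causal attention needs every token of the batch
// in the same micro-batch, so the batch must not be split.
inline bool fits_single_ubatch(int n_tokens, int n_ubatch) {
    return count_ubatches(n_tokens, n_ubatch) <= 1;
}
```

Under these assumptions, a batch filled to n_batch = 1024 with n_ubatch = 512 would need two micro-batches, which is the split the commit prevents by capping n_batch at n_ubatch.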
1 parent 79dbae0 commit 59977eb

1 file changed: +10 -0 lines changed


tools/server/server.cpp

Lines changed: 10 additions & 0 deletions
@@ -73,8 +73,18 @@ int main(int argc, char ** argv, char ** envp) {
         return 1;
     }
 
+    // validate batch size for embeddings
+    // embeddings require all tokens to be processed in a single ubatch
+    // see https://github.com/ggml-org/llama.cpp/issues/12836
+    if (params.embedding && params.n_batch > params.n_ubatch) {
+        LOG_WRN("%s: embeddings enabled with n_batch (%d) > n_ubatch (%d)\n", __func__, params.n_batch, params.n_ubatch);
+        LOG_WRN("%s: setting n_batch = n_ubatch = %d to avoid assertion failure\n", __func__, params.n_ubatch);
+        params.n_batch = params.n_ubatch;
+    }
+
     if (params.n_parallel < 0) {
         LOG_INF("%s: n_parallel is set to auto, using n_parallel = 4 and kv_unified = true\n", __func__);
+
         params.n_parallel = 4;
         params.kv_unified = true;
     }
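
The startup check in this diff can be tried in isolation. The sketch below mirrors its logic outside the server. The `batch_params` struct and the `clamp_batch_for_embeddings` function are hypothetical stand-ins for illustration (the real code operates on the server's params object and logs through LOG_WRN), and plain `fprintf` replaces the logging macro.

```cpp
#include <cstdio>

// Hypothetical stand-in for the three fields the check reads;
// the actual server uses its own params structure.
struct batch_params {
    bool embedding; // embeddings mode enabled
    int  n_batch;   // logical batch size (-b)
    int  n_ubatch;  // physical micro-batch size (-ub)
};

// Caps n_batch at n_ubatch when embeddings are enabled.
// Returns true if the value was changed.
inline bool clamp_batch_for_embeddings(batch_params & p) {
    if (p.embedding && p.n_batch > p.n_ubatch) {
        std::fprintf(stderr,
                     "embeddings enabled with n_batch (%d) > n_ubatch (%d); "
                     "setting n_batch = %d\n",
                     p.n_batch, p.n_ubatch, p.n_ubatch);
        p.n_batch = p.n_ubatch;
        return true;
    }
    return false; // parameters already valid, or embeddings disabled
}
```

Returning a flag instead of only logging is a choice made for this sketch so the adjustment can be asserted in a test. The commit itself adjusts the value in place and emits two warnings.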
