diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 4590b239212..6db46cc550d 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -5825,9 +5825,11 @@ class Gemma3Model(TextModel): norm_shift = 1.0 # Gemma3RMSNorm adds 1.0 to the norm value def set_vocab(self): - self._set_vocab_sentencepiece() - - self.gguf_writer.add_add_space_prefix(False) + if (self.dir_model / "tokenizer.model").is_file(): + self._set_vocab_sentencepiece() + self.gguf_writer.add_add_space_prefix(False) + else: + self._set_vocab_gpt2() def set_gguf_parameters(self): hparams = self.hparams @@ -5845,13 +5847,24 @@ def set_gguf_parameters(self): self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 1_000_000.0)) # for global layers # attn_logit_softcapping is removed in Gemma3 assert hparams.get("attn_logit_softcapping") is None - self.gguf_writer.add_sliding_window(hparams["sliding_window"]) + if (final_logit_softcap := hparams.get("final_logit_softcapping")): + self.gguf_writer.add_final_logit_softcapping(final_logit_softcap) + if hparams.get("sliding_window_pattern") != 1: + self.gguf_writer.add_sliding_window(hparams["sliding_window"]) self.gguf_writer.add_head_count_kv(hparams.get("num_key_value_heads", 4)) if hparams.get("rope_scaling") is not None: - assert hparams["rope_scaling"]["rope_type"] == "linear" - # important: this rope_scaling is only applied for global layers, and not used by 1B model - self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) - self.gguf_writer.add_rope_scaling_factor(hparams["rope_scaling"]["factor"]) + rope_scaling = hparams["rope_scaling"] + if rope_scaling["rope_type"] == "linear": + # important: this rope_scaling is only applied for global layers, and not used by 1B model + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) + self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"]) + elif rope_scaling["rope_type"] == "yarn": + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN) + self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"]) + self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"]) + self.gguf_writer.add_rope_scaling_yarn_ext_factor(rope_scaling["extrapolation_factor"]) + self.gguf_writer.add_rope_scaling_yarn_beta_fast(rope_scaling["beta_fast"]) + self.gguf_writer.add_rope_scaling_yarn_beta_slow(rope_scaling["beta_slow"]) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: del bid # unused @@ -5865,8 +5878,10 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter # remove OOV (out-of-vocabulary) rows in token_embd if "embed_tokens.weight" in name: - vocab = self._create_vocab_sentencepiece() - tokens = vocab[0] + if (self.dir_model / "tokenizer.model").is_file(): + tokens = self._create_vocab_sentencepiece()[0] + else: + tokens = self.get_vocab_base()[0] data_torch = data_torch[:len(tokens)] # ref code in Gemma3RMSNorm diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index fbd538109ba..84a0c2934e4 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -67,7 +67,7 @@ add_library(llama models/gemma-embedding.cpp models/gemma.cpp models/gemma2-iswa.cpp - models/gemma3-iswa.cpp + models/gemma3.cpp models/gemma3n-iswa.cpp models/glm4-moe.cpp models/glm4.cpp diff --git a/src/llama-model.cpp b/src/llama-model.cpp index c3675dbdc41..6b76b2f7423 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1264,18 +1264,25 @@ void llama_model::load_hparams(llama_model_loader & ml) { } break; case LLM_ARCH_GEMMA3: { - hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - hparams.set_swa_pattern(6); + const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + if (found_swa && hparams.n_swa > 0) { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + hparams.set_swa_pattern(6); - hparams.rope_freq_base_train_swa = 10000.0f; - hparams.rope_freq_scale_train_swa = 1.0f; + hparams.rope_freq_base_train_swa = 10000.0f; + hparams.rope_freq_scale_train_swa = 1.0f; + } else { + hparams.swa_type = LLAMA_SWA_TYPE_NONE; + } - ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + hparams.f_final_logit_softcapping = 0.0f; + ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { case 18: type = LLM_TYPE_270M; break; case 26: type = LLM_TYPE_1B; break; + case 32: type = LLM_TYPE_8B; break; // Rnj-1 case 34: type = LLM_TYPE_4B; break; case 48: type = LLM_TYPE_12B; break; case 62: type = LLM_TYPE_27B; break; @@ -7300,7 +7307,11 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { } break; case LLM_ARCH_GEMMA3: { - llm = std::make_unique(*this, params); + if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) { + llm = std::make_unique>(*this, params); + } else { + llm = std::make_unique>(*this, params); + } } break; case LLM_ARCH_GEMMA3N: { diff --git a/src/models/gemma3-iswa.cpp b/src/models/gemma3.cpp similarity index 78% rename from src/models/gemma3-iswa.cpp rename to src/models/gemma3.cpp index 839ff6d3d93..ae60ef4790c 100644 --- a/src/models/gemma3-iswa.cpp +++ b/src/models/gemma3.cpp @@ -1,6 +1,7 @@ #include "models.h" -llm_build_gemma3_iswa::llm_build_gemma3_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +template +llm_build_gemma3::llm_build_gemma3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_k; ggml_tensor * cur; @@ -17,13 +18,28 @@ llm_build_gemma3_iswa::llm_build_gemma3_iswa(const llama_model & model, const ll ggml_tensor * inp_pos = build_inp_pos(); // TODO: is causal == true correct? might need some changes - auto * inp_attn = build_attn_inp_kv_iswa(); + using inp_attn_type = std::conditional_t; + inp_attn_type * inp_attn = nullptr; + + if constexpr (iswa) { + inp_attn = build_attn_inp_kv_iswa(); + } else { + inp_attn = build_attn_inp_kv(); + } ggml_tensor * inp_out_ids = build_inp_out_ids(); for (int il = 0; il < n_layer; ++il) { - const float freq_base_l = model.get_rope_freq_base (cparams, il); - const float freq_scale_l = model.get_rope_freq_scale(cparams, il); + float freq_base_l = 0.0f; + float freq_scale_l = 0.0f; + + if constexpr (iswa) { + freq_base_l = model.get_rope_freq_base (cparams, il); + freq_scale_l = model.get_rope_freq_scale(cparams, il); + } else { + freq_base_l = freq_base; + freq_scale_l = freq_scale; + } // norm cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); @@ -102,7 +118,7 @@ llm_build_gemma3_iswa::llm_build_gemma3_iswa(const llama_model & model, const ll cur = build_norm(cur, model.layers[il].ffn_post_norm, NULL, LLM_NORM_RMS, -1); - cb(cur, "ffn_post_norm", -1); + cb(cur, "ffn_post_norm", il); cur = ggml_add(ctx0, cur, sa_out); @@ -124,8 +140,17 @@ llm_build_gemma3_iswa::llm_build_gemma3_iswa(const llama_model & model, const ll // lm_head cur = build_lora_mm(model.output, cur); + if (hparams.f_final_logit_softcapping) { + cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping); + cur = ggml_tanh(ctx0, cur); + cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping); + } + cb(cur, "result_output", -1); res->t_logits = cur; ggml_build_forward_expand(gf, cur); } + +template struct llm_build_gemma3; +template struct llm_build_gemma3; diff --git a/src/models/models.h b/src/models/models.h index d93601ad06a..6494f545018 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -179,8 +179,9 @@ struct llm_build_gemma2_iswa : public llm_graph_context { llm_build_gemma2_iswa(const llama_model & model, const llm_graph_params & params); }; -struct llm_build_gemma3_iswa : public llm_graph_context { - llm_build_gemma3_iswa(const llama_model & model, const llm_graph_params & params); +template +struct llm_build_gemma3 : public llm_graph_context { + llm_build_gemma3(const llama_model & model, const llm_graph_params & params); }; struct llm_build_gemma3n_iswa : public llm_graph_context {