48 changes: 48 additions & 0 deletions convert_hf_to_gguf.py
@@ -6098,6 +6098,54 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
return super().modify_tensors(data_torch, name, bid)


# Note: HF checkpoints use architecture "Gemma3ForCausalLM" since a separate
# model implementation is not needed for Transformers. To convert to GGUF,
# change the architecture in config.json to "Rnj1ForCausalLM".
@ModelBase.register("Rnj1ForCausalLM")
class Rnj1Model(TextModel):
model_arch = gguf.MODEL_ARCH.RNJ1
norm_shift = 1.0 # Gemma3RMSNorm adds 1.0 to the norm value

def set_gguf_parameters(self):
hparams = self.hparams
block_count = hparams["num_hidden_layers"]

# some default values are not specified in the hparams
self.gguf_writer.add_context_length(hparams.get("max_position_embeddings", 32768))
self.gguf_writer.add_embedding_length(hparams["hidden_size"])
self.gguf_writer.add_block_count(block_count)
self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
self.gguf_writer.add_head_count(hparams.get("num_attention_heads", 8))
self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("rms_norm_eps", 1e-6))
self.gguf_writer.add_key_length(hparams.get("head_dim", 128))
self.gguf_writer.add_value_length(hparams.get("head_dim", 128))
self.gguf_writer.add_file_type(self.ftype)
self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 10_000.0)) # for global layers
assert hparams.get("attn_logit_softcapping") is None
self.gguf_writer.add_head_count_kv(hparams.get("num_key_value_heads", 4))
rope_scaling = self.hparams.get("rope_scaling") or {}
if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling:
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
self.gguf_writer.add_rope_scaling_yarn_ext_factor(rope_scaling["extrapolation_factor"])
self.gguf_writer.add_rope_scaling_yarn_beta_fast(rope_scaling["beta_fast"])
self.gguf_writer.add_rope_scaling_yarn_beta_slow(rope_scaling["beta_slow"])

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused

if "language_model." in name:
name = name.replace("language_model.", "")

# ref code in Gemma3RMSNorm
# output = output * (1.0 + self.weight.float())
if name.endswith("norm.weight"):
data_torch = data_torch + self.norm_shift

return [(self.map_tensor_name(name), data_torch)]


@ModelBase.register("Starcoder2ForCausalLM")
class StarCoder2Model(TextModel):
model_arch = gguf.MODEL_ARCH.STARCODER2
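The note at the top of Rnj1Model is the one manual conversion step: published checkpoints declare "Gemma3ForCausalLM", so config.json has to be retargeted to "Rnj1ForCausalLM" before convert_hf_to_gguf.py will dispatch to this class. A minimal sketch of that edit; the checkpoint directory and output name below are placeholders, not values from this PR:

# Hypothetical pre-conversion step: point config.json at the RNJ1 converter class.
import json
from pathlib import Path

model_dir = Path("path/to/rnj1-hf-checkpoint")   # placeholder local checkpoint directory
config_path = model_dir / "config.json"

config = json.loads(config_path.read_text())
config["architectures"] = ["Rnj1ForCausalLM"]    # shipped as ["Gemma3ForCausalLM"]
config_path.write_text(json.dumps(config, indent=2))

# then, for example:
#   python convert_hf_to_gguf.py path/to/rnj1-hf-checkpoint --outfile rnj1.gguf

Note that set_gguf_parameters only emits the YaRN metadata when config.json carries a rope_scaling block with type "yarn" and a "factor"; otherwise just rope_theta (default 10,000) is written for the global layers.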
20 changes: 20 additions & 0 deletions gguf-py/gguf/constants.py
@@ -385,6 +385,7 @@ class MODEL_ARCH(IntEnum):
GEMMA3 = auto()
GEMMA3N = auto()
GEMMA_EMBEDDING = auto()
RNJ1 = auto()
STARCODER2 = auto()
RWKV6 = auto()
RWKV6QWEN2 = auto()
@@ -758,6 +759,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.GEMMA3: "gemma3",
MODEL_ARCH.GEMMA3N: "gemma3n",
MODEL_ARCH.GEMMA_EMBEDDING: "gemma-embedding",
MODEL_ARCH.RNJ1: "rnj1",
MODEL_ARCH.STARCODER2: "starcoder2",
MODEL_ARCH.RWKV6: "rwkv6",
MODEL_ARCH.RWKV6QWEN2: "rwkv6qwen2",
@@ -1930,6 +1932,24 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_PRE_NORM,
MODEL_TENSOR.FFN_POST_NORM,
],
MODEL_ARCH.RNJ1: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_POST_NORM,
MODEL_TENSOR.FFN_PRE_NORM,
MODEL_TENSOR.FFN_POST_NORM,
],
MODEL_ARCH.STARCODER2: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
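With the enum value, name mapping, and tensor list added, gguf-py exposes the architecture like any other. A small sketch of inspecting it, assuming a gguf-py build that includes this change (MODEL_ARCH_NAMES and MODEL_TENSORS are the existing lookup tables in constants.py):

import gguf

# "rnj1" is the architecture string recorded in the GGUF header.
print(gguf.MODEL_ARCH_NAMES[gguf.MODEL_ARCH.RNJ1])       # "rnj1"

# The tensor whitelist mirrors GEMMA3: QK-norm plus attention/FFN pre- and post-norms.
for tensor in gguf.MODEL_TENSORS[gguf.MODEL_ARCH.RNJ1]:
    print(tensor.name)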
1 change: 1 addition & 0 deletions src/CMakeLists.txt
@@ -117,6 +117,7 @@ add_library(llama
models/qwen3next.cpp
models/refact.cpp
models/rnd1.cpp
models/rnj1.cpp
models/rwkv6-base.cpp
models/rwkv6.cpp
models/rwkv6qwen2.cpp
22 changes: 22 additions & 0 deletions src/llama-arch.cpp
@@ -50,6 +50,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_GEMMA3, "gemma3" },
{ LLM_ARCH_GEMMA3N, "gemma3n" },
{ LLM_ARCH_GEMMA_EMBEDDING, "gemma-embedding" },
{ LLM_ARCH_RNJ1, "rnj1" },
{ LLM_ARCH_STARCODER2, "starcoder2" },
{ LLM_ARCH_MAMBA, "mamba" },
{ LLM_ARCH_MAMBA2, "mamba2" },
@@ -1225,6 +1226,27 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
},
},
{
LLM_ARCH_RNJ1,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
},
},
{
LLM_ARCH_STARCODER2,
{
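These LLM_TENSOR_NAMES entries are the on-disk names the loader will look up (blk.%d.attn_q, blk.%d.post_attention_norm, blk.%d.post_ffw_norm, and so on). A quick sanity check against a converted file, assuming gguf-py's GGUFReader and a local rnj1.gguf (placeholder path):

from gguf import GGUFReader

reader = GGUFReader("rnj1.gguf")   # placeholder: any converted RNJ1 model

# Per-block tensors should follow the blk.%d.* scheme from the table above.
for t in reader.tensors:
    if t.name.startswith("blk.0."):
        print(t.name, list(t.shape))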
1 change: 1 addition & 0 deletions src/llama-arch.h
@@ -54,6 +54,7 @@ enum llm_arch {
LLM_ARCH_GEMMA3,
LLM_ARCH_GEMMA3N,
LLM_ARCH_GEMMA_EMBEDDING,
LLM_ARCH_RNJ1,
LLM_ARCH_STARCODER2,
LLM_ARCH_MAMBA,
LLM_ARCH_MAMBA2,
17 changes: 17 additions & 0 deletions src/llama-model.cpp
@@ -1335,6 +1335,17 @@ void llama_model::load_hparams(llama_model_loader & ml) {
hparams.f_attention_scale = 1.0f / std::sqrt(float(hparams.n_embd_head_k));

} break;
case LLM_ARCH_RNJ1:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);

switch (hparams.n_layer) {
case 32: type = LLM_TYPE_8B; break;
default: type = LLM_TYPE_UNKNOWN;
}

hparams.f_attention_scale = 1.0f / std::sqrt(float(hparams.n_embd_head_k));
} break;
case LLM_ARCH_STARCODER2:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -3899,6 +3910,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
}
} break;
case LLM_ARCH_RNJ1:
case LLM_ARCH_GEMMA3:
case LLM_ARCH_GEMMA_EMBEDDING:
{
@@ -7310,6 +7322,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
{
llm = std::make_unique<llm_build_gemma_embedding>(*this, params);
} break;
case LLM_ARCH_RNJ1:
{
llm = std::make_unique<llm_build_rnj1>(*this, params);
} break;
case LLM_ARCH_STARCODER2:
{
llm = std::make_unique<llm_build_starcoder2>(*this, params);
@@ -7767,6 +7783,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_GEMMA3:
case LLM_ARCH_GEMMA3N:
case LLM_ARCH_GEMMA_EMBEDDING:
case LLM_ARCH_RNJ1:
case LLM_ARCH_STARCODER2:
case LLM_ARCH_OPENELM:
case LLM_ARCH_GPTNEOX:
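For LLM_ARCH_RNJ1, load_hparams reads only the RMS-norm epsilon explicitly, labels the 32-layer configuration as 8B, and sets f_attention_scale to 1/sqrt(n_embd_head_k); tensor creation and the RoPE type reuse the existing Gemma3 branches through the added case labels. A quick check of the resulting scale, assuming the converter's head_dim fallback of 128 (config.json may specify a different value):

import math

# hparams.f_attention_scale for head_dim = 128 (assumed converter default)
print(1.0 / math.sqrt(128))   # ≈ 0.0884; rnj1.cpp applies this to Q and passes 1.0f to build_attn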
4 changes: 4 additions & 0 deletions src/models/models.h
@@ -213,6 +213,10 @@ struct llm_build_gemma : public llm_graph_context {
llm_build_gemma(const llama_model & model, const llm_graph_params & params);
};

struct llm_build_rnj1 : public llm_graph_context {
llm_build_rnj1(const llama_model & model, const llm_graph_params & params);
};

struct llm_build_glm4 : public llm_graph_context {
llm_build_glm4(const llama_model & model, const llm_graph_params & params);
};
132 changes: 132 additions & 0 deletions src/models/rnj1.cpp
@@ -0,0 +1,132 @@
#include "models.h"
#include "../llama-impl.h"

// Based on the Gemma3 implementation. The main differences are:
// - all layers are global
// - use YaRN for long-context

llm_build_rnj1::llm_build_rnj1(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_k;

ggml_tensor * cur;
ggml_tensor * inpL;

inpL = build_inp_embd(model.tok_embd);

// important: do not normalize weights for raw embeddings input (i.e. encoded image embeddings)
if (ubatch.token) {
inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
cb(inpL, "inp_scaled", -1);
}
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

auto * inp_attn = build_attn_inp_kv();

ggml_tensor * inp_out_ids = build_inp_out_ids();

for (int il = 0; il < n_layer; ++il) {
// norm
cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
cb(cur, "attn_norm", il);

// self-attention
{
// compute Q and K and RoPE them
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);

ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
cb(Kcur, "Kcur", il);

ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
cb(Vcur, "Vcur", il);

Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);

Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
cb(Qcur, "Qcur_normed", il);

Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);

Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
cb(Kcur, "Kcur_normed", il);

Kcur = ggml_rope_ext(
ctx0, Kcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);

cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);

// ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/model.py#L315
Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale);

cur = build_attn(inp_attn,
model.layers[il].wo, NULL,
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il);
}
if (il == n_layer - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
}
cur = build_norm(cur,
model.layers[il].attn_post_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "attn_post_norm", il);

ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL);
cb(sa_out, "sa_out", il);

cur = build_norm(sa_out,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "ffn_norm", il);

// feed-forward network
{
cur = build_ffn(cur,
model.layers[il].ffn_up, NULL, NULL,
model.layers[il].ffn_gate, NULL, NULL,
model.layers[il].ffn_down, NULL, NULL,
NULL,
LLM_FFN_GELU, LLM_FFN_PAR, il);
cb(cur, "ffn_out", il);
}
cur = build_norm(cur,
model.layers[il].ffn_post_norm, NULL,
LLM_NORM_RMS, -1);
cb(cur, "ffn_post_norm", il);

cur = ggml_add(ctx0, cur, sa_out);

cur = build_cvec(cur, il);
cb(cur, "l_out", il);

// input for next layer
inpL = cur;
}
cur = inpL;

cur = build_norm(cur,
model.output_norm, NULL,
LLM_NORM_RMS, -1);

cb(cur, "result_norm", -1);
res->t_embd = cur;

// lm_head
cur = build_lora_mm(model.output, cur);

cb(cur, "result_output", -1);
res->t_logits = cur;

ggml_build_forward_expand(gf, cur);
}