diff --git a/.gitignore b/.gitignore
index 978e887aa..4fa33c0b2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -20,3 +20,4 @@ docs/build
 docs/source/generated
 **.orig
 .venv
+
diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py
index 025b43793..e97e0e9da 100644
--- a/transformer_lens/HookedTransformer.py
+++ b/transformer_lens/HookedTransformer.py
@@ -143,7 +143,6 @@ def __init__(
         )
 
         self.cfg = HookedTransformerConfig.unwrap(cfg)
-
         if tokenizer is not None:
             self.set_tokenizer(tokenizer, default_padding_side=default_padding_side)
         elif self.cfg.tokenizer_name is not None:
@@ -161,13 +160,18 @@ def __init__(
             if "phi" in self.cfg.tokenizer_name.lower():
                 use_fast = False
             huggingface_token = os.environ.get("HF_TOKEN", "")
+            add_bos_token = self.cfg.original_architecture not in [
+                "OlmoForCausalLM",
+                "OlmoeForCausalLM",
+                "Olmo2ForCausalLM",
+            ]
             self.set_tokenizer(
                 AutoTokenizer.from_pretrained(
                     self.cfg.tokenizer_name,
-                    add_bos_token=True,
                     trust_remote_code=self.cfg.trust_remote_code,
                     use_fast=use_fast,
                     token=huggingface_token if len(huggingface_token) > 0 else None,
+                    add_bos_token=add_bos_token,
                 ),
                 default_padding_side=default_padding_side,
             )
@@ -734,7 +738,14 @@ def set_tokenizer(
         # tokenizers like LlamaTokenizer are different when bos token is automatically/manually
         # prepended, and add_bos_token cannot be dynamically controlled after initialization
         # (https://github.com/huggingface/transformers/issues/25886).
-        tokenizer_with_bos = utils.get_tokenizer_with_bos(tokenizer)
+        if self.cfg.original_architecture not in [
+            "OlmoForCausalLM",
+            "OlmoeForCausalLM",
+            "Olmo2ForCausalLM",
+        ]:
+            tokenizer_with_bos = utils.get_tokenizer_with_bos(tokenizer)
+        else:
+            tokenizer_with_bos = tokenizer
         self.tokenizer = tokenizer_with_bos
         self.tokenizer.padding_side = default_padding_side
 
@@ -1798,18 +1809,18 @@ def fold_layer_norm(
         if not self.cfg.final_rms and fold_biases:
             # Dumb bug from my old SoLU training code, some models have RMSNorm instead of LayerNorm
             # pre unembed.
-            state_dict[f"unembed.b_U"] = state_dict[f"unembed.b_U"] + (
-                state_dict[f"unembed.W_U"] * state_dict[f"ln_final.b"][:, None]
+            state_dict["unembed.b_U"] = state_dict["unembed.b_U"] + (
+                state_dict["unembed.W_U"] * state_dict["ln_final.b"][:, None]
             ).sum(dim=-2)
-            del state_dict[f"ln_final.b"]
+            del state_dict["ln_final.b"]
 
-        state_dict[f"unembed.W_U"] = state_dict[f"unembed.W_U"] * state_dict[f"ln_final.w"][:, None]
-        del state_dict[f"ln_final.w"]
+        state_dict["unembed.W_U"] = state_dict["unembed.W_U"] * state_dict["ln_final.w"][:, None]
+        del state_dict["ln_final.w"]
 
         if center_weights:
             # Center the weights that read in from the LayerNormPre
-            state_dict[f"unembed.W_U"] -= einops.reduce(
-                state_dict[f"unembed.W_U"], "d_model d_vocab -> 1 d_vocab", "mean"
+            state_dict["unembed.W_U"] -= einops.reduce(
+                state_dict["unembed.W_U"], "d_model d_vocab -> 1 d_vocab", "mean"
             )
 
         return state_dict
@@ -1821,13 +1832,17 @@ def center_writing_weights(self, state_dict: Dict[str, torch.Tensor]):
         W_out. This is done by subtracting the mean of the weights from the weights themselves. This
         is done in-place. See fold_layer_norm for more details.
         """
-        state_dict["embed.W_E"] = state_dict["embed.W_E"] - state_dict["embed.W_E"].mean(
-            -1, keepdim=True
-        )
-        if self.cfg.positional_embedding_type != "rotary":
-            state_dict["pos_embed.W_pos"] = state_dict["pos_embed.W_pos"] - state_dict[
-                "pos_embed.W_pos"
-            ].mean(-1, keepdim=True)
+        if self.cfg.original_architecture == "Olmo2ForCausalLM":
+            print("Not centering embedding weights for Olmo2ForCausalLM")
+            pass  # skip centering: the attention input of the first layer is not normalized in OLMo 2
+        else:
+            state_dict["embed.W_E"] = state_dict["embed.W_E"] - state_dict["embed.W_E"].mean(
+                -1, keepdim=True
+            )
+            if self.cfg.positional_embedding_type != "rotary":
+                state_dict["pos_embed.W_pos"] = state_dict["pos_embed.W_pos"] - state_dict[
+                    "pos_embed.W_pos"
+                ].mean(-1, keepdim=True)
         for l in range(self.cfg.n_layers):
             state_dict[f"blocks.{l}.attn.W_O"] = state_dict[f"blocks.{l}.attn.W_O"] - state_dict[
                 f"blocks.{l}.attn.W_O"
diff --git a/transformer_lens/HookedTransformerConfig.py b/transformer_lens/HookedTransformerConfig.py
index 9cf16f578..14c93424f 100644
--- a/transformer_lens/HookedTransformerConfig.py
+++ b/transformer_lens/HookedTransformerConfig.py
@@ -192,8 +192,7 @@ class HookedTransformerConfig:
         NTK_by_parts_factor (float): The overall factor used in the "NTK-by-parts" method that
            affects the rate of change between low and high-frequency interpolation strategies.
            Defaults to 8.0.
-
-
+        norm_topk_prob (bool): Whether to normalize the top-k probabilities in the MoE layer.
     """
 
     n_layers: int
@@ -264,6 +263,7 @@ class HookedTransformerConfig:
     NTK_by_parts_high_freq_factor: float = 4.0
     NTK_by_parts_factor: float = 8.0
     NTK_original_ctx_len: int = 8192
+    norm_topk_prob: bool = False
 
     def __post_init__(self):
         if self.n_heads == -1:
diff --git a/transformer_lens/components/abstract_attention.py b/transformer_lens/components/abstract_attention.py
index 0aee43814..5e9eaca42 100644
--- a/transformer_lens/components/abstract_attention.py
+++ b/transformer_lens/components/abstract_attention.py
@@ -154,6 +154,18 @@ def __init__(
         # will be overwritten by the child T5Attention class
         self.has_relative_attention_bias = False
 
+        if (
+            self.cfg.original_architecture == "OlmoeForCausalLM"
+            or self.cfg.original_architecture == "Olmo2ForCausalLM"
+        ):
+            self.q_norm = RMSNorm(self.cfg, self.cfg.d_model)
+            k_norm_dim = (
+                self.cfg.d_model
+                if self.cfg.original_architecture == "Olmo2ForCausalLM"
+                else self.cfg.d_head * self.cfg.n_key_value_heads
+            )
+            self.k_norm = RMSNorm(self.cfg, k_norm_dim)
+
     @property
     def OV(self) -> FactoredMatrix:
         """
@@ -209,6 +221,32 @@ def forward(
 
         q, k, v = self.calculate_qkv_matrices(query_input, key_input, value_input)
 
+        # OLMoE and OLMo 2 apply QK-norm to the queries and keys.
+        if (
+            self.cfg.original_architecture == "OlmoeForCausalLM"
+            or self.cfg.original_architecture == "Olmo2ForCausalLM"
+        ):
+            q = einops.rearrange(
+                self.q_norm(
+                    einops.rearrange(
+                        q,
+                        "batch pos head_index d_head -> batch pos (head_index d_head)",
+                    )
+                ),
+                "batch kv_pos (head_index d_head) -> batch kv_pos head_index d_head",
+                head_index=q.shape[2],
+            )
+            k = einops.rearrange(
+                self.k_norm(
+                    einops.rearrange(
+                        k,
+                        "batch pos head_index d_head -> batch pos (head_index d_head)",
+                    )
+                ),
+                "batch kv_pos (head_index d_head) -> batch kv_pos head_index d_head",
+                head_index=k.shape[2],
+            )
+
         if past_kv_cache_entry is not None:
             # Appends the new keys and values to the cached values, and automatically updates the cache
             kv_cache_pos_offset = past_kv_cache_entry.past_keys.size(1)
@@ -244,9 +282,10 @@ def forward(
                 )
 
             # Take the last query_ctx positions so it also works with past_kv_cache
-            attn_scores += self.alibi[
-                :, -query_ctx:, :key_ctx
-            ]  # [batch, head_index, query_pos, key_pos]
+            if self.alibi is not None:  # only add once alibi has been initialized
+                attn_scores += self.alibi[
+                    :, -query_ctx:, :key_ctx
+                ]  # [batch, head_index, query_pos, key_pos]
         elif self.cfg.positional_embedding_type == "relative_positional_bias":
             if position_bias is None:
                 if self.has_relative_attention_bias:
@@ -260,7 +299,8 @@ def forward(
                         device=attn_scores.device,
                     )
 
-            attn_scores += position_bias
+            if position_bias is not None:  # only add when a position bias is available
+                attn_scores += position_bias
         if self.cfg.attention_dir == "causal":
             # If causal attention, we mask it to only attend backwards. If bidirectional, we don't mask.
             attn_scores = self.apply_causal_mask(
diff --git a/transformer_lens/components/mlps/moe.py b/transformer_lens/components/mlps/moe.py
index e01f25ee9..c343fd141 100644
--- a/transformer_lens/components/mlps/moe.py
+++ b/transformer_lens/components/mlps/moe.py
@@ -88,7 +88,8 @@ def forward(
         # both are [batch, pos, experts_per_token]
         weights = self.hook_expert_weights(F.softmax(gate_logits, dim=1, dtype=torch.float))
         weights, expert_indices = torch.topk(weights, self.experts_per_token, dim=-1)
-        weights /= weights.sum(dim=-1, keepdim=True)
+        if self.cfg.norm_topk_prob:
+            weights /= weights.sum(dim=-1, keepdim=True)
         expert_indices = self.hook_expert_indices(expert_indices)
         weights = weights.to(x.dtype)
 
diff --git a/transformer_lens/components/transformer_block.py b/transformer_lens/components/transformer_block.py
index dcce1586a..4d45e9b50 100644
--- a/transformer_lens/components/transformer_block.py
+++ b/transformer_lens/components/transformer_block.py
@@ -153,26 +153,37 @@ def forward(
             key_input = attn_in
             value_input = attn_in
 
-        attn_out = (
-            # hook the residual stream states that are used to calculate the
-            # queries, keys and values, independently.
-            # Then take the layer norm of these inputs, and pass these to the attention module.
-            self.attn(
-                query_input=self.ln1(query_input)
-                + (0.0 if shortformer_pos_embed is None else shortformer_pos_embed),
-                key_input=self.ln1(key_input)
-                + (0.0 if shortformer_pos_embed is None else shortformer_pos_embed),
-                value_input=self.ln1(value_input),
+        if self.cfg.original_architecture == "Olmo2ForCausalLM":
+            attn_out = self.attn(
+                query_input=query_input,
+                key_input=key_input,
+                value_input=value_input,
                 past_kv_cache_entry=past_kv_cache_entry,
                 attention_mask=attention_mask,
             )
-        )  # [batch, pos, d_model]
+        else:
+            attn_out = (
+                # hook the residual stream states that are used to calculate the
+                # queries, keys and values, independently.
+                # Then take the layer norm of these inputs, and pass these to the attention module.
+                self.attn(
+                    query_input=self.ln1(query_input)
+                    + (0.0 if shortformer_pos_embed is None else shortformer_pos_embed),
+                    key_input=self.ln1(key_input)
+                    + (0.0 if shortformer_pos_embed is None else shortformer_pos_embed),
+                    value_input=self.ln1(value_input),
+                    past_kv_cache_entry=past_kv_cache_entry,
+                    attention_mask=attention_mask,
+                )
+            )  # [batch, pos, d_model]
         if self.cfg.use_normalization_before_and_after:
             # If we use LayerNorm both before and after, then apply the second LN after the layer
             # and before the hook. We do it before the hook so hook_attn_out captures "that which
             # is added to the residual stream"
             attn_out = self.ln1_post(attn_out)
         attn_out = self.hook_attn_out(attn_out)
+        if self.cfg.original_architecture == "Olmo2ForCausalLM":
+            attn_out = self.ln1(attn_out)
 
         if resid_pre.device != attn_out.device:
             resid_pre = resid_pre.to(attn_out.device)
@@ -182,8 +193,12 @@ def forward(
             mlp_in = (
                 resid_mid if not self.cfg.use_hook_mlp_in else self.hook_mlp_in(resid_mid.clone())
             )
-            normalized_resid_mid = self.ln2(mlp_in)
-            mlp_out = self.apply_mlp(normalized_resid_mid)
+            if self.cfg.original_architecture == "Olmo2ForCausalLM":
+                mlp_out = self.apply_mlp(mlp_in)
+                mlp_out = self.ln2(mlp_out)
+            else:
+                normalized_resid_mid = self.ln2(mlp_in)
+                mlp_out = self.apply_mlp(normalized_resid_mid)
             resid_post = self.hook_resid_post(resid_mid + mlp_out)  # [batch, pos, d_model]
         elif self.cfg.parallel_attn_mlp:
             # Dumb thing done by GPT-J, both MLP and Attn read from resid_pre and write to resid_post, no resid_mid used.
diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py
index 8bfb6315d..37e4ce5c8 100644
--- a/transformer_lens/loading_from_pretrained.py
+++ b/transformer_lens/loading_from_pretrained.py
@@ -37,6 +37,9 @@
     convert_neel_solu_old_weights,
     convert_neo_weights,
     convert_neox_weights,
+    convert_olmo2_weights,
+    convert_olmo_weights,
+    convert_olmoe_weights,
     convert_opt_weights,
     convert_phi3_weights,
     convert_phi_weights,
@@ -263,6 +266,27 @@
     "google-t5/t5-base",
     "google-t5/t5-large",
     "ai-forever/mGPT",
+    "allenai/OLMo-1B-hf",
+    "allenai/OLMo-7B-hf",
+    "allenai/OLMo-7B-0724-hf",
+    "allenai/OLMo-7B-0724-SFT-hf",
+    "allenai/OLMo-7B-0724-Instruct-hf",
+    "allenai/OLMo-7B-0424-hf",
+    "allenai/OLMo-7B-Twin-2T-hf",
+    "allenai/OLMo-1B-0724-hf",
+    "allenai/OLMo-7B-Instruct-hf",
+    "allenai/OLMo-7B-SFT-hf",
+    "allenai/OLMoE-1B-7B-0924",
+    "allenai/OLMoE-1B-7B-0924-SFT",
+    "allenai/OLMoE-1B-7B-0924-Instruct",
+    "allenai/OLMo-2-0425-1B",
+    "allenai/OLMo-2-0425-1B-SFT",
+    "allenai/OLMo-2-0425-1B-DPO",
+    "allenai/OLMo-2-0425-1B-Instruct",
+    "allenai/OLMo-2-1124-7B",
+    "allenai/OLMo-2-1124-7B-SFT",
+    "allenai/OLMo-2-1124-7B-DPO",
+    "allenai/OLMo-2-1124-7B-Instruct",
 ]
 """Official model names for models on HuggingFace."""
 
@@ -1563,6 +1587,102 @@ def convert_hf_model_config(model_name: str, **kwargs: Any):
             "final_rms": True,
             "use_normalization_before_and_after": True,
         }
+    elif official_model_name.startswith("allenai/OLMo-1B") and official_model_name.endswith("hf"):
+        cfg_dict = {
+            "d_model": 2048,
+            "d_head": 128,
+            "n_heads": 16,
+            "d_mlp": 8192,
+            "n_layers": 16,
+            "n_ctx": 2048,
+            "eps": 1e-05,
+            "d_vocab": 50304,
+            "act_fn": "silu",
+            "initializer_range": 0.02,
+            "normalization_type": "LN",
+            "rotary_base": 10000.0,
+            "attn_types": ["global"] * 16,
+            "positional_embedding_type": "rotary",
+            "gated_mlp": True,
+        }
+    elif official_model_name.startswith("allenai/OLMo-7B") and official_model_name.endswith("hf"):
+        cfg_dict = {
+            "d_model": 4096,
+            "d_head": 128,
+            "n_heads": 32,
+            "d_mlp": 11008,
+            "n_layers": 32,
+            "n_ctx": 2048,
+            "eps": 1e-05,
+            "d_vocab": 50304,
+            "act_fn": "silu",
+            "initializer_range": 0.02,
+            "normalization_type": "LN",
+            "rotary_base": 10000.0,
+            "attn_types": ["global"] * 32,
+            "positional_embedding_type": "rotary",
+            "gated_mlp": True,
+        }
+    elif official_model_name.startswith("allenai/OLMo-2-0425-1B"):
+        cfg_dict = {
+            "d_model": 2048,
+            "d_head": 128,
+            "n_heads": 16,
+            "d_mlp": 8192,
+            "n_layers": 16,
+            "n_ctx": 4096,
+            "eps": 1e-06,
+            "d_vocab": 100352,
+            "act_fn": "silu",
+            "initializer_range": 0.02,
+            "normalization_type": "RMS",
+            "rotary_base": 500000.0,
+            "attn_types": ["global"] * 16,
+            "positional_embedding_type": "rotary",
+            "gated_mlp": True,
+        }
+    elif official_model_name.startswith("allenai/OLMo-2-1124-7B"):
+        cfg_dict = {
+            "d_model": 4096,
+            "d_head": 128,
+            "n_heads": 32,
+            "d_mlp": 11008,
+            "n_layers": 32,
+            "n_ctx": 4096,
+            "eps": 1e-06,
+            "d_vocab": 100352,
+            "act_fn": "silu",
+            "initializer_range": 0.02,
+            "normalization_type": "RMS",
+            "rotary_base": 500000.0,
+            "attn_types": ["global"] * 32,
+            "positional_embedding_type": "rotary",
+            "gated_mlp": True,
+        }
+    elif architecture == "OlmoeForCausalLM":
+        cfg_dict = {
+            "d_model": hf_config.hidden_size,
+            "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
+            "n_heads": hf_config.num_attention_heads,
+            "d_mlp": hf_config.intermediate_size,
+            "n_layers": hf_config.num_hidden_layers,
+            "n_ctx": hf_config.max_position_embeddings,
+            "eps": hf_config.rms_norm_eps,
+            "d_vocab": hf_config.vocab_size,
+            "act_fn": hf_config.hidden_act,
+            "num_experts": hf_config.num_experts,
+            "experts_per_token": hf_config.num_experts_per_tok,
+            "norm_topk_prob": hf_config.norm_topk_prob,
+            "n_key_value_heads": hf_config.num_key_value_heads,
+            "rotary_base": hf_config.rope_theta,
+            "tie_word_embeddings": hf_config.tie_word_embeddings,
+            "initializer_range": hf_config.initializer_range,
+            "positional_embedding_type": "rotary",
+            "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
+            "final_rms": True,
+            "gated_mlp": True,
+            "normalization_type": "LN",
+        }
     elif architecture == "T5ForConditionalGeneration":
         cfg_dict = {
             "d_model": hf_config.d_model,
@@ -1986,6 +2106,12 @@ def get_pretrained_state_dict(
             state_dict = convert_gemma_weights(hf_model, cfg)
         elif cfg.original_architecture == "Gemma2ForCausalLM":
             state_dict = convert_gemma_weights(hf_model, cfg)
+        elif cfg.original_architecture == "OlmoForCausalLM":
+            state_dict = convert_olmo_weights(hf_model, cfg)
+        elif cfg.original_architecture == "Olmo2ForCausalLM":
+            state_dict = convert_olmo2_weights(hf_model, cfg)
+        elif cfg.original_architecture == "OlmoeForCausalLM":
+            state_dict = convert_olmoe_weights(hf_model, cfg)
         else:
             raise ValueError(
                 f"Loading weights from the architecture is not currently supported: {cfg.original_architecture}, generated from model name {cfg.model_name}. Feel free to open an issue on GitHub to request this feature."
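Not part of the patch: with the model names and config mappings above registered, a quick way to review a converted checkpoint is to compare logits against the HuggingFace reference. The snippet below is a minimal sketch assuming both copies of the 1B OLMo 2 model fit in CPU memory; the prompt and the tolerance check are illustrative only, and `from_pretrained_no_processing` is used so that weight-processing steps (layer-norm folding, weight centering) do not enter the comparison.

```python
# Sanity-check sketch (assumes the patch above is applied and the checkpoint is downloadable).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from transformer_lens import HookedTransformer

model_name = "allenai/OLMo-2-0425-1B"
hf_model = AutoModelForCausalLM.from_pretrained(model_name)
hf_tokenizer = AutoTokenizer.from_pretrained(model_name)
tl_model = HookedTransformer.from_pretrained_no_processing(model_name, device="cpu")

# Tokenize once with the HF tokenizer and feed the same tokens to both models,
# so BOS handling cannot introduce a spurious mismatch.
tokens = hf_tokenizer("The capital of France is", return_tensors="pt")["input_ids"]

with torch.no_grad():
    hf_logits = hf_model(tokens).logits
    tl_logits = tl_model(tokens, return_type="logits")

# The two implementations should agree up to numerical tolerance.
print(torch.max(torch.abs(hf_logits - tl_logits)))
```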
diff --git a/transformer_lens/pretrained/weight_conversions/__init__.py b/transformer_lens/pretrained/weight_conversions/__init__.py
index c5ea9581b..bba841a29 100644
--- a/transformer_lens/pretrained/weight_conversions/__init__.py
+++ b/transformer_lens/pretrained/weight_conversions/__init__.py
@@ -19,3 +19,6 @@
 from .nanogpt import convert_nanogpt_weights
 from .t5 import convert_t5_weights
 from .neel_solu_old import convert_neel_solu_old_weights
+from .olmo import convert_olmo_weights
+from .olmoe import convert_olmoe_weights
+from .olmo2 import convert_olmo2_weights
diff --git a/transformer_lens/pretrained/weight_conversions/olmo.py b/transformer_lens/pretrained/weight_conversions/olmo.py
new file mode 100644
index 000000000..38b4e0800
--- /dev/null
+++ b/transformer_lens/pretrained/weight_conversions/olmo.py
@@ -0,0 +1,50 @@
+import einops
+import torch
+
+from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
+
+
+def convert_olmo_weights(olmo, cfg: HookedTransformerConfig):
+    state_dict = {}
+
+    assert cfg.d_mlp is not None
+
+    state_dict["embed.W_E"] = olmo.model.embed_tokens.weight
+    for l in range(cfg.n_layers):
+        olmo_layer = olmo.model.layers[l]
+
+        W_Q = olmo_layer.self_attn.q_proj.weight
+        W_K = olmo_layer.self_attn.k_proj.weight
+        W_V = olmo_layer.self_attn.v_proj.weight
+        W_Q = einops.rearrange(W_Q, "(i h) m->i m h", i=cfg.n_heads)
+        W_K = einops.rearrange(W_K, "(i h) m->i m h", i=cfg.n_heads)
+        W_V = einops.rearrange(W_V, "(i h) m->i m h", i=cfg.n_heads)
+        state_dict[f"blocks.{l}.attn.W_Q"] = W_Q
+        state_dict[f"blocks.{l}.attn.W_K"] = W_K
+        state_dict[f"blocks.{l}.attn.W_V"] = W_V
+
+        W_O = olmo_layer.self_attn.o_proj.weight
+        W_O = einops.rearrange(W_O, "m (n h)->n h m", n=cfg.n_heads)
+        state_dict[f"blocks.{l}.attn.W_O"] = W_O
+
+        state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)
+
+        state_dict[f"blocks.{l}.mlp.W_in"] = olmo_layer.mlp.up_proj.weight.T
+        state_dict[f"blocks.{l}.mlp.W_gate"] = olmo_layer.mlp.gate_proj.weight.T
+        state_dict[f"blocks.{l}.mlp.b_in"] = torch.zeros(cfg.d_mlp, dtype=cfg.dtype)
+
+        state_dict[f"blocks.{l}.mlp.W_out"] = olmo_layer.mlp.down_proj.weight.T
+        state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)
+
+        state_dict[f"blocks.{l}.ln1.w"] = torch.ones(cfg.d_model, dtype=cfg.dtype)
+        state_dict[f"blocks.{l}.ln1.b"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)
+        state_dict[f"blocks.{l}.ln2.w"] = torch.ones(cfg.d_model, dtype=cfg.dtype)
+        state_dict[f"blocks.{l}.ln2.b"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)
+
+    state_dict["ln_final.w"] = torch.ones(cfg.d_model, dtype=cfg.dtype)
+    state_dict["ln_final.b"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)
+
+    state_dict["unembed.W_U"] = olmo.lm_head.weight.T
+    state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype)
+
+    return state_dict
diff --git a/transformer_lens/pretrained/weight_conversions/olmo2.py b/transformer_lens/pretrained/weight_conversions/olmo2.py
new file mode 100644
index 000000000..1696a5dc2
--- /dev/null
+++ b/transformer_lens/pretrained/weight_conversions/olmo2.py
@@ -0,0 +1,56 @@
+import einops
+import torch
+from transformers.models.olmo2.modeling_olmo2 import Olmo2ForCausalLM
+
+from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
+
+
+def convert_olmo2_weights(olmo2: Olmo2ForCausalLM, cfg: HookedTransformerConfig):
+    state_dict = {}
+
+    assert cfg.d_mlp is not None
+
+    state_dict["embed.W_E"] = olmo2.model.embed_tokens.weight
+
+    for l in range(cfg.n_layers):
+        olmo2_layer = olmo2.model.layers[l]  # type: ignore
+
+        W_Q = olmo2_layer.self_attn.q_proj.weight
+        W_K = olmo2_layer.self_attn.k_proj.weight
+        W_V = olmo2_layer.self_attn.v_proj.weight
+        W_Q = einops.rearrange(W_Q, "(n h) m->n m h", n=cfg.n_heads)
+        W_K = einops.rearrange(W_K, "(n h) m->n m h", n=cfg.n_heads)
+        W_V = einops.rearrange(W_V, "(n h) m->n m h", n=cfg.n_heads)
+        state_dict[f"blocks.{l}.attn.W_Q"] = W_Q
+        state_dict[f"blocks.{l}.attn.W_K"] = W_K
+        state_dict[f"blocks.{l}.attn.W_V"] = W_V
+        state_dict[f"blocks.{l}.attn.q_norm.w"] = olmo2_layer.self_attn.q_norm.weight
+        state_dict[f"blocks.{l}.attn.k_norm.w"] = olmo2_layer.self_attn.k_norm.weight
+
+        state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype)
+        state_dict[f"blocks.{l}.attn.b_K"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype)
+        state_dict[f"blocks.{l}.attn.b_V"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype)
+
+        W_O = olmo2_layer.self_attn.o_proj.weight
+        W_O = einops.rearrange(W_O, "m (n h)->n h m", n=cfg.n_heads)
+        state_dict[f"blocks.{l}.attn.W_O"] = W_O
+
+        state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)
+
+        state_dict[f"blocks.{l}.ln1.w"] = olmo2_layer.post_attention_layernorm.weight
+
+        state_dict[f"blocks.{l}.mlp.W_in"] = olmo2_layer.mlp.up_proj.weight.T
+        state_dict[f"blocks.{l}.mlp.W_gate"] = olmo2_layer.mlp.gate_proj.weight.T
+        state_dict[f"blocks.{l}.mlp.b_in"] = torch.zeros(cfg.d_mlp, dtype=cfg.dtype)
+
+        state_dict[f"blocks.{l}.mlp.W_out"] = olmo2_layer.mlp.down_proj.weight.T
+        state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)
+
+        state_dict[f"blocks.{l}.ln2.w"] = olmo2_layer.post_feedforward_layernorm.weight
+
+    state_dict["ln_final.w"] = olmo2.model.norm.weight
+
+    state_dict["unembed.W_U"] = olmo2.lm_head.weight.T
+    state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype)
+
+    return state_dict
diff --git a/transformer_lens/pretrained/weight_conversions/olmoe.py b/transformer_lens/pretrained/weight_conversions/olmoe.py
new file mode 100644
index 000000000..d850dfbbe
--- /dev/null
+++ b/transformer_lens/pretrained/weight_conversions/olmoe.py
@@ -0,0 +1,66 @@
+import einops
+import torch
+
+from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
+
+
+def convert_olmoe_weights(olmoe, cfg: HookedTransformerConfig):
+    state_dict = {}
+
+    assert cfg.n_key_value_heads is not None
+    assert cfg.d_mlp is not None
+    assert cfg.num_experts is not None
+
+    state_dict["embed.W_E"] = olmoe.model.embed_tokens.weight
+
+    for l in range(cfg.n_layers):
+        olmoe_layer = olmoe.model.layers[l]
+        state_dict[f"blocks.{l}.ln1.w"] = olmoe_layer.input_layernorm.weight
+
+        W_Q = olmoe_layer.self_attn.q_proj.weight
+        W_K = olmoe_layer.self_attn.k_proj.weight
+        W_V = olmoe_layer.self_attn.v_proj.weight
+        W_Q = einops.rearrange(W_Q, "(n h) m->n m h", n=cfg.n_heads)
+        W_K = einops.rearrange(W_K, "(n h) m->n m h", n=cfg.n_key_value_heads)
+        W_V = einops.rearrange(W_V, "(n h) m->n m h", n=cfg.n_key_value_heads)
+        state_dict[f"blocks.{l}.attn.W_Q"] = W_Q
+        state_dict[f"blocks.{l}.attn._W_K"] = W_K
+        state_dict[f"blocks.{l}.attn._W_V"] = W_V
+        state_dict[f"blocks.{l}.attn.q_norm.w"] = olmoe_layer.self_attn.q_norm.weight
+        state_dict[f"blocks.{l}.attn.k_norm.w"] = olmoe_layer.self_attn.k_norm.weight
+
+        state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype)
+        state_dict[f"blocks.{l}.attn._b_K"] = torch.zeros(
+            cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype
+        )
+        state_dict[f"blocks.{l}.attn._b_V"] = torch.zeros(
+            cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype
+        )
+
+        W_O = olmoe_layer.self_attn.o_proj.weight
+        W_O = einops.rearrange(W_O, "m (n h)->n h m", n=cfg.n_heads)
+        state_dict[f"blocks.{l}.attn.W_O"] = W_O
+
+        state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)
+
+        state_dict[f"blocks.{l}.ln2.w"] = olmoe_layer.post_attention_layernorm.weight
+
+        state_dict[f"blocks.{l}.mlp.W_gate.weight"] = olmoe_layer.mlp.gate.weight
+
+        for e in range(cfg.num_experts):
+            state_dict[f"blocks.{l}.mlp.experts.{e}.W_in.weight"] = olmoe_layer.mlp.experts[
+                e
+            ].up_proj.weight
+            state_dict[f"blocks.{l}.mlp.experts.{e}.W_gate.weight"] = olmoe_layer.mlp.experts[
+                e
+            ].gate_proj.weight
+            state_dict[f"blocks.{l}.mlp.experts.{e}.W_out.weight"] = olmoe_layer.mlp.experts[
+                e
+            ].down_proj.weight
+
+    state_dict["ln_final.w"] = olmoe.model.norm.weight
+
+    state_dict["unembed.W_U"] = olmoe.lm_head.weight.T
+    state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype)
+
+    return state_dict
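Also not part of the patch: a minimal usage sketch for the newly supported checkpoints once the diff is applied. The model name is taken from the additions to `OFFICIAL_MODEL_NAMES` above; the prompt and the printed shapes are illustrative.

```python
from transformer_lens import HookedTransformer

# Any of the newly registered names works here, e.g. the 1B OLMo 2 checkpoint.
model = HookedTransformer.from_pretrained("allenai/OLMo-2-0425-1B")

logits, cache = model.run_with_cache("OLMo 2 is an open language model.")
print(model.cfg.original_architecture)            # "Olmo2ForCausalLM"
print(logits.shape)                               # [batch, pos, d_vocab]
print(cache["blocks.0.attn.hook_pattern"].shape)  # [batch, n_heads, query_pos, key_pos]
```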