diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index 8bfb6315d..74b0c917b 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -1903,7 +1903,7 @@ def get_pretrained_state_dict( hf_model = AutoModelForCausalLM.from_pretrained( official_model_name, revision=f"checkpoint-{cfg.checkpoint_value}", - torch_dtype=dtype, + dtype=dtype, token=huggingface_token if len(huggingface_token) > 0 else None, **kwargs, ) @@ -1911,7 +1911,7 @@ def get_pretrained_state_dict( hf_model = AutoModelForCausalLM.from_pretrained( official_model_name, revision=f"step{cfg.checkpoint_value}", - torch_dtype=dtype, + dtype=dtype, token=huggingface_token, **kwargs, ) @@ -1924,21 +1924,21 @@ def get_pretrained_state_dict( elif "bert" in official_model_name: hf_model = BertForPreTraining.from_pretrained( official_model_name, - torch_dtype=dtype, + dtype=dtype, token=huggingface_token if len(huggingface_token) > 0 else None, **kwargs, ) elif "t5" in official_model_name: hf_model = T5ForConditionalGeneration.from_pretrained( official_model_name, - torch_dtype=dtype, + dtype=dtype, token=huggingface_token if len(huggingface_token) > 0 else None, **kwargs, ) else: hf_model = AutoModelForCausalLM.from_pretrained( official_model_name, - torch_dtype=dtype, + dtype=dtype, token=huggingface_token if len(huggingface_token) > 0 else None, **kwargs, )