From eccc5ec2de3c6a719bd84e7d39706e64803b772c Mon Sep 17 00:00:00 2001 From: Rishab Alagharu Date: Mon, 5 Jan 2026 17:47:40 -0500 Subject: [PATCH] Fixed deprecated torch_dtype argument from HuggingFace transformers: replaced with dtype to avoid warning --- transformer_lens/loading_from_pretrained.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index 8bfb6315d..74b0c917b 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -1903,7 +1903,7 @@ def get_pretrained_state_dict( hf_model = AutoModelForCausalLM.from_pretrained( official_model_name, revision=f"checkpoint-{cfg.checkpoint_value}", - torch_dtype=dtype, + dtype=dtype, token=huggingface_token if len(huggingface_token) > 0 else None, **kwargs, ) @@ -1911,7 +1911,7 @@ def get_pretrained_state_dict( hf_model = AutoModelForCausalLM.from_pretrained( official_model_name, revision=f"step{cfg.checkpoint_value}", - torch_dtype=dtype, + dtype=dtype, token=huggingface_token, **kwargs, ) @@ -1924,21 +1924,21 @@ def get_pretrained_state_dict( elif "bert" in official_model_name: hf_model = BertForPreTraining.from_pretrained( official_model_name, - torch_dtype=dtype, + dtype=dtype, token=huggingface_token if len(huggingface_token) > 0 else None, **kwargs, ) elif "t5" in official_model_name: hf_model = T5ForConditionalGeneration.from_pretrained( official_model_name, - torch_dtype=dtype, + dtype=dtype, token=huggingface_token if len(huggingface_token) > 0 else None, **kwargs, ) else: hf_model = AutoModelForCausalLM.from_pretrained( official_model_name, - torch_dtype=dtype, + dtype=dtype, token=huggingface_token if len(huggingface_token) > 0 else None, **kwargs, )