diff --git a/src/transformers/integrations/peft.py b/src/transformers/integrations/peft.py
index 4c1c12a4058e..32c8709c5fe2 100644
--- a/src/transformers/integrations/peft.py
+++ b/src/transformers/integrations/peft.py
@@ -279,7 +279,7 @@ def load_adapter(
             )
             peft_config.inference_mode = not is_trainable
 
-        if peft_config.peft_type != PeftType.LORA:
+        if hotswap and (peft_config.peft_type != PeftType.LORA):
             raise ValueError("Hotswapping is currently only supported for LoRA, please set `hotswap=False`.")
 
         if not hotswap:
diff --git a/tests/peft_integration/test_peft_integration.py b/tests/peft_integration/test_peft_integration.py
index 1e0e2335067d..0f4e16964265 100644
--- a/tests/peft_integration/test_peft_integration.py
+++ b/tests/peft_integration/test_peft_integration.py
@@ -889,6 +889,60 @@ def test_peft_pipeline_no_warning(self):
             # Generate text to verify pipeline works
             _ = lora_generator(text, max_new_tokens=20)
 
+    def test_non_lora_load_adapter(self):
+        """
+        Check that loading a non-LoRA adapter works. Using LoKr as an example, not testing all possible PEFT methods.
+        """
+        from peft import LoKrConfig, get_peft_model
+
+        inputs = torch.randint(0, 100, (1, 10)).to(torch_device)
+        atol, rtol = 1e-4, 1e-4
+
+        for model_id in self.transformers_test_model_ids:
+            for transformers_class in self.transformers_test_model_classes:
+                model = transformers_class.from_pretrained(model_id).to(torch_device)
+                with torch.inference_mode():
+                    output_base = model(inputs).logits
+
+                peft_config = LoKrConfig(init_weights=False)
+                peft_model = get_peft_model(model, peft_config)
+                with torch.inference_mode():
+                    output_peft = peft_model(inputs).logits
+
+                # sanity check: should be different
+                assert not torch.allclose(output_base, output_peft, atol=atol, rtol=rtol)
+
+                with tempfile.TemporaryDirectory() as tmpdirname:
+                    peft_model.save_pretrained(tmpdirname)
+                    del model, peft_model
+
+                    model = transformers_class.from_pretrained(tmpdirname).to(torch_device)
+                    with torch.inference_mode():
+                        output_transformers = model(inputs).logits
+                    self.assertTrue(torch.allclose(output_peft, output_transformers, atol=atol, rtol=rtol))
+
+    def test_non_lora_add_adapter(self):
+        """
+        Check that adding a non-LoRA adapter works. Using LoKr as an example, not testing all possible PEFT methods.
+        """
+        from peft import LoKrConfig
+
+        inputs = torch.randint(0, 100, (1, 10)).to(torch_device)
+        atol, rtol = 1e-4, 1e-4
+
+        for model_id in self.transformers_test_model_ids:
+            for transformers_class in self.transformers_test_model_classes:
+                model = transformers_class.from_pretrained(model_id).to(torch_device)
+                with torch.inference_mode():
+                    output_base = model(inputs).logits
+
+                peft_config = LoKrConfig(init_weights=False)
+                model.add_adapter(peft_config)
+                with torch.inference_mode():
+                    output_peft = model(inputs).logits
+                # should be different
+                assert not torch.allclose(output_base, output_peft, atol=atol, rtol=rtol)
+
 
 @require_peft
 @require_torch
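
Usage note (not part of the diff): a minimal sketch of what the changed condition enables, namely loading a non-LoRA PEFT adapter such as LoKr into a plain transformers model when hotswapping is not requested. The tiny base model id and the local adapter directory name below are placeholders chosen for illustration, not values taken from the patch.

# Sketch only: create and save a LoKr adapter with PEFT, then load it back
# into a fresh transformers model via load_adapter().
import torch
from peft import LoKrConfig, get_peft_model
from transformers import AutoModelForCausalLM

base_model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"  # placeholder tiny model

# Build a LoKr adapter on top of the base model and save only the adapter weights.
base = AutoModelForCausalLM.from_pretrained(base_model_id)
peft_model = get_peft_model(base, LoKrConfig(init_weights=False))
peft_model.save_pretrained("lokr-adapter")  # placeholder local directory

# Before this fix, load_adapter() raised "Hotswapping is currently only supported
# for LoRA" for any non-LoRA adapter even when hotswapping was not requested;
# with the patch, the check only triggers when hotswap=True.
fresh = AutoModelForCausalLM.from_pretrained(base_model_id)
fresh.load_adapter("lokr-adapter")

with torch.inference_mode():
    logits = fresh(torch.randint(0, 100, (1, 10))).logits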