2 changes: 1 addition & 1 deletion src/transformers/integrations/peft.py

@@ -279,7 +279,7 @@ def load_adapter(
         )
         peft_config.inference_mode = not is_trainable
 
-        if peft_config.peft_type != PeftType.LORA:
+        if hotswap and (peft_config.peft_type != PeftType.LORA):
             raise ValueError("Hotswapping is currently only supported for LoRA, please set `hotswap=False`.")
 
         if not hotswap:
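The one-line change above relaxes the guard in `load_adapter`: previously any non-LoRA adapter was rejected outright, even though the LoRA-only restriction matters solely for hotswapping. With the fix, the `ValueError` is raised only when `hotswap=True` is actually requested. A minimal sketch of the resulting behavior (the adapter repo name is hypothetical, not a real checkpoint):

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

# With the fix, a non-LoRA adapter (e.g. LoKr) loads fine as long as
# hotswapping is not requested (hotswap defaults to False).
model.load_adapter("my-org/opt-lokr-adapter")  # hypothetical adapter repo

# Hotswapping remains LoRA-only; for a non-LoRA adapter this still raises:
# model.load_adapter("my-org/opt-lokr-adapter", hotswap=True)  # ValueError
```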
54 changes: 54 additions & 0 deletions tests/peft_integration/test_peft_integration.py

@@ -889,6 +889,60 @@ def test_peft_pipeline_no_warning(self):
         # Generate text to verify pipeline works
         _ = lora_generator(text, max_new_tokens=20)
 
+    def test_non_lora_load_adapter(self):
+        """
+        Check that loading a non-LoRA adapter works. LoKr serves as the example; not every PEFT method is tested.
+        """
+        from peft import LoKrConfig, get_peft_model
+
+        inputs = torch.randint(0, 100, (1, 10)).to(torch_device)
+        atol, rtol = 1e-4, 1e-4
+
+        for model_id in self.transformers_test_model_ids:
+            for transformers_class in self.transformers_test_model_classes:
+                model = transformers_class.from_pretrained(model_id).to(torch_device)
+                with torch.inference_mode():
+                    output_base = model(inputs).logits
+
+                peft_config = LoKrConfig(init_weights=False)
+                peft_model = get_peft_model(model, peft_config)
+                with torch.inference_mode():
+                    output_peft = peft_model(inputs).logits
+
+                # sanity check: the adapter should change the output
+                assert not torch.allclose(output_base, output_peft, atol=atol, rtol=rtol)
+
+                with tempfile.TemporaryDirectory() as tmpdirname:
+                    peft_model.save_pretrained(tmpdirname)
+                    del model, peft_model
+
+                    model = transformers_class.from_pretrained(tmpdirname).to(torch_device)
+                    with torch.inference_mode():
+                        output_transformers = model(inputs).logits
+                    self.assertTrue(torch.allclose(output_peft, output_transformers, atol=atol, rtol=rtol))
+
+    def test_non_lora_add_adapter(self):
+        """
+        Check that adding a non-LoRA adapter works. LoKr serves as the example; not every PEFT method is tested.
+        """
+        from peft import LoKrConfig
+
+        inputs = torch.randint(0, 100, (1, 10)).to(torch_device)
+        atol, rtol = 1e-4, 1e-4
+
+        for model_id in self.transformers_test_model_ids:
+            for transformers_class in self.transformers_test_model_classes:
+                model = transformers_class.from_pretrained(model_id).to(torch_device)
+                with torch.inference_mode():
+                    output_base = model(inputs).logits
+
+                peft_config = LoKrConfig(init_weights=False)
+                model.add_adapter(peft_config)
+                with torch.inference_mode():
+                    output_peft = model(inputs).logits
+                # the adapter should change the output
+                assert not torch.allclose(output_base, output_peft, atol=atol, rtol=rtol)
+
 
 @require_peft
 @require_torch
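The tests iterate over `self.transformers_test_model_ids` and `self.transformers_test_model_classes`, fixtures defined elsewhere on the test class. Outside the test harness, the `add_adapter` path they exercise reduces to the following sketch (the model id is illustrative, and it assumes PEFT can infer default target modules for that architecture):

```python
import torch
from peft import LoKrConfig
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
inputs = torch.randint(0, 100, (1, 10))

with torch.inference_mode():
    base_logits = model(inputs).logits

# init_weights=False initializes the LoKr weights randomly, so the output
# changes immediately -- handy for asserting that the adapter is active.
model.add_adapter(LoKrConfig(init_weights=False))

with torch.inference_mode():
    peft_logits = model(inputs).logits

assert not torch.allclose(base_logits, peft_logits, atol=1e-4, rtol=1e-4)
```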