import os.path
import sys

import torch

from torchbenchmark import REPO_PATH
from torchbenchmark.tasks import NLP

from ...util.model import BenchmarkModel

# lit-llama is vendored as a git submodule; it has to be on sys.path before
# anything from the lit_llama package can be imported.
LIT_LLAMA_PATH = os.path.join(REPO_PATH, "submodules", "lit-llama")
sys.path.insert(0, LIT_LLAMA_PATH)

from lit_llama import LLaMA
from lit_llama.lora import lora, mark_only_lora_as_trainable
from lit_llama.utils import EmptyInitOnDevice, lazy_load, llama_model_lookup

class Model(BenchmarkModel):
    task = NLP.LANGUAGE_MODELING
    DEFAULT_EVAL_BSIZE = 1
    DEFAULT_TRAIN_BSIZE = 4  # micro_batch_size in lit-llama's finetune/lora.py

    def __init__(self, test, device, jit=False, batch_size=None, extra_args=[]):
        super().__init__(test=test, device=device, jit=jit, batch_size=batch_size, extra_args=extra_args)

        # Hyperparameters taken from lit-llama's finetune/lora.py
        lora_r = 8
        lora_alpha = 16
        lora_dropout = 0.05

        checkpoint_path = os.path.join(LIT_LLAMA_PATH, "checkpoints/lit-llama/7B/lit-llama.pth")
        if not os.path.exists(checkpoint_path):
            raise NotImplementedError(f"lit-llama 7B checkpoint not found at {checkpoint_path}")
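
        # `lazy_load` defers reading checkpoint tensors until they are accessed,
        # and the `lora` context manager makes LLaMA build its attention
        # projections with LoRA adapters (rank `lora_r`) attached.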
        with lazy_load(checkpoint_path) as checkpoint, lora(r=lora_r, alpha=lora_alpha, dropout=lora_dropout, enabled=True):
            name = llama_model_lookup(checkpoint)

            with EmptyInitOnDevice(device=device):
                model = LLaMA.from_name(name)
                # strict=False: the LoRA weights are not part of the base checkpoint.
                model.load_state_dict(checkpoint, strict=False)

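            # Freeze the base model; only the LoRA adapter parameters remain trainable.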
            mark_only_lora_as_trainable(model)

        self.model = model
        self.seq_len = 32
        self.max_seq_len = 64
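        # Example inputs for the model's forward pass: dummy token ids of shape
        # [batch_size, seq_len] plus the maximum sequence length.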
        self.example_inputs = (
            torch.ones([self.batch_size, self.seq_len], dtype=torch.int32, device=self.device),
            self.max_seq_len,
        )

    def get_module(self):
        return self.model, self.example_inputs

    def train(self):
        logits = self.model(*self.example_inputs)
        # Summing the logits gives a scalar proxy loss: it exercises the forward
        # and backward passes without needing labels or a real fine-tuning objective.
        logits.sum().backward()

    def eval(self):
        self.model.eval()
        with torch.no_grad():
            logits = self.model(*self.example_inputs)
        return (logits,)