
Commit 02ff72b

ezyang authored and facebook-github-bot committed
Add lit-llama benchmarks (logits, autoregressive generation, lora fine tuning) (#1730)
Summary:
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: #1730
Reviewed By: xuzhao9
Differential Revision: D47295535
Pulled By: ezyang
fbshipit-source-id: 9e5998569b6dc58c6c918c4caaf1e6c896600430
1 parent 9df5215 commit 02ff72b

12 files changed, 258 additions and 0 deletions

.gitmodules

Lines changed: 3 additions & 0 deletions
@@ -1,3 +1,6 @@
 [submodule "submodules/FAMBench"]
 	path = submodules/FAMBench
 	url = https://github.com/facebookresearch/FAMBench.git
+[submodule "submodules/lit-llama"]
+	path = submodules/lit-llama
+	url = https://github.com/Lightning-AI/lit-llama.git
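
With this entry in place, the lit-llama sources must be fetched before the benchmark code below can import from them. A minimal sketch of initializing the submodule from Python, assuming git is available and the working directory is the repository root (the snippet is illustrative, not part of this commit):

import subprocess

# Fetch the lit-llama submodule declared in .gitmodules above.
subprocess.run(
    ["git", "submodule", "update", "--init", "submodules/lit-llama"],
    check=True,
)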

submodules/lit-llama

Submodule lit-llama added at 8aa65ba
Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
from ...util.model import BenchmarkModel
from torchbenchmark.tasks import NLP
import torch
import os
from torchbenchmark import add_path, REPO_PATH
import sys
import lightning as L

LIT_LLAMA_PATH = os.path.join(REPO_PATH, "submodules", "lit-llama")

with add_path(LIT_LLAMA_PATH):
    from lit_llama.utils import EmptyInitOnDevice, lazy_load, llama_model_lookup
    from lit_llama import LLaMA, Tokenizer

class Model(BenchmarkModel):
    task = NLP.LANGUAGE_MODELING
    DEFAULT_EVAL_BSIZE = 1

    def __init__(self, test, device, jit=False, batch_size=None, extra_args=[]):
        super().__init__(test=test, device=device, jit=jit, batch_size=batch_size, extra_args=extra_args)

        checkpoint_path = os.path.join(LIT_LLAMA_PATH, "checkpoints/lit-llama/7B/lit-llama.pth")
        if not os.path.exists(checkpoint_path):
            raise NotImplementedError("checkpoint doesn't exist")
        with lazy_load(checkpoint_path) as checkpoint:
            name = llama_model_lookup(checkpoint)

            with EmptyInitOnDevice(device=device):
                model = LLaMA.from_name(name)
            model.load_state_dict(checkpoint)

        self.model = model
        self.seq_len = 32
        self.max_seq_len = 64
        self.example_inputs = (
            torch.ones([self.batch_size, self.seq_len], dtype=torch.int32, device=self.device),
            self.max_seq_len,
            torch.arange(self.seq_len, dtype=torch.int64, device=self.device)  # positions
        )

    def get_module(self):
        return self.model, self.example_inputs

    def train(self):
        return NotImplementedError("you will OOM trying to train directly")

    def eval(self):
        self.model.eval()
        with torch.no_grad():
            logits = self.model(*self.example_inputs)
        return (logits,)
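
The class above follows the standard TorchBench interface (get_module/eval), so a rough smoke test could look like the sketch below. It assumes the 7B checkpoint already sits at checkpoints/lit-llama/7B/lit-llama.pth inside the submodule and that a CUDA device is available; it is not part of the commit:

# Hypothetical usage of the Model class defined above.
m = Model(test="eval", device="cuda")
module, example_inputs = m.get_module()
(logits,) = m.eval()               # single forward pass under torch.no_grad()
print(logits.shape, logits.dtype)  # logits for the 32-token dummy batch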
Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
from torchbenchmark.util.framework.lit_llama import install_lit_llama

if __name__ == '__main__':
    install_lit_llama()
Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
devices:
  NVIDIA A100-SXM4-40GB:
    eval_batch_size: 32
eval_benchmark: false
eval_deterministic: false
eval_nograd: true
train_benchmark: false
train_deterministic: false
not_implemented:
- test: eval
- test: example
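
The metadata above is plain YAML consumed by the TorchBench harness. As a rough illustration of its structure, it could be parsed with PyYAML as sketched below; the file path is hypothetical, since the target directory is not shown in this extract:

import yaml

# Hypothetical path; metadata.yaml sits next to the benchmark's __init__.py.
with open("metadata.yaml") as f:
    metadata = yaml.safe_load(f)

print(metadata["devices"]["NVIDIA A100-SXM4-40GB"]["eval_batch_size"])  # 32
print(metadata["not_implemented"])  # [{'test': 'eval'}, {'test': 'example'}]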
Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
from .. import lit_llama as lit_llama
from ..lit_llama import LIT_LLAMA_PATH
import importlib.util
import os.path
import torch  # torch is referenced directly in eval() below (torch.no_grad)
import torch.nn as nn
import sys
from lit_llama import Tokenizer

def import_from_file_path(module_name, file_path):
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    sys.modules[module_name] = module
    return module

lit_llama_generate = import_from_file_path("lit_llama_generate", os.path.join(LIT_LLAMA_PATH, 'generate.py'))

class GenerationWrapper(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, idx, max_new_tokens):
        return lit_llama_generate.generate(self.model, idx, max_new_tokens)

class Model(lit_llama.Model):
    def __init__(self, test, device, jit=False, batch_size=None, extra_args=[]):
        super().__init__(test=test, device=device, jit=jit, batch_size=batch_size, extra_args=extra_args)
        self.model = GenerationWrapper(self.model)
        tokenizer = Tokenizer(os.path.join(LIT_LLAMA_PATH, "checkpoints/lit-llama/tokenizer.model"))
        # max_new_tokens matches lit-llama/generate.py
        self.example_inputs = (tokenizer.encode("The meaning of life is", bos=True, eos=False, device=device), 50)

    def train(self):
        return NotImplementedError("cannot train on autoregressive generation")

    def eval(self):
        self.model.eval()
        with torch.no_grad():
            y = self.model(*self.example_inputs)
        return (y,)
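
A corresponding smoke test for the autoregressive generation benchmark could look like the sketch below; it assumes the checkpoint and tokenizer.model are present under submodules/lit-llama/checkpoints and is not part of the commit:

# Hypothetical usage of the generation Model defined above.
m = Model(test="eval", device="cuda")
(tokens,) = m.eval()   # runs generate() on the encoded prompt
print(tokens.shape)    # prompt tokens plus up to 50 generated token ids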
Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
from torchbenchmark.util.framework.lit_llama import install_lit_llama

if __name__ == '__main__':
    install_lit_llama()
Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
devices:
  NVIDIA A100-SXM4-40GB:
    eval_batch_size: 32
eval_benchmark: false
eval_deterministic: false
eval_nograd: true
train_benchmark: false
train_deterministic: false
not_implemented:
- test: eval
- test: example
Lines changed: 66 additions & 0 deletions
@@ -0,0 +1,66 @@
from ...util.model import BenchmarkModel
from torchbenchmark.tasks import NLP
import torch
from ..lit_llama import LIT_LLAMA_PATH
import importlib.util
import os.path
import torch.nn as nn
import sys
from lit_llama.lora import mark_only_lora_as_trainable, lora, lora_state_dict
from torchbenchmark import REPO_PATH

LIT_LLAMA_PATH = os.path.join(REPO_PATH, "submodules", "lit-llama")

sys.path.insert(0, LIT_LLAMA_PATH)

from lit_llama.utils import EmptyInitOnDevice, lazy_load, llama_model_lookup
from lit_llama import LLaMA, Tokenizer

class Model(BenchmarkModel):
    task = NLP.LANGUAGE_MODELING
    DEFAULT_EVAL_BSIZE = 1
    DEFAULT_TRAIN_BSIZE = 4  # micro_batch_size in lora.py

    def __init__(self, test, device, jit=False, batch_size=None, extra_args=[]):
        super().__init__(test=test, device=device, jit=jit, batch_size=batch_size, extra_args=extra_args)

        # From finetune/lora.py hyperparameters
        lora_r = 8
        lora_alpha = 16
        lora_dropout = 0.05

        checkpoint_path = os.path.join(LIT_LLAMA_PATH, "checkpoints/lit-llama/7B/lit-llama.pth")
        if not os.path.exists(checkpoint_path):
            raise NotImplementedError("checkpoint doesn't exist")
        with lazy_load(checkpoint_path) as checkpoint, lora(r=lora_r, alpha=lora_alpha, dropout=lora_dropout, enabled=True):
            name = llama_model_lookup(checkpoint)

            with EmptyInitOnDevice(device=device):
                model = LLaMA.from_name(name)
            # LoRA weights won't be in base checkpoint
            model.load_state_dict(checkpoint, strict=False)

        mark_only_lora_as_trainable(model)

        self.model = model
        self.seq_len = 32
        self.max_seq_len = 64
        self.example_inputs = (
            torch.ones([self.batch_size, self.seq_len], dtype=torch.int32, device=self.device),
            self.max_seq_len,
        )

    def get_module(self):
        return self.model, self.example_inputs

    def train(self):
        logits = self.model(*self.example_inputs)
        logits.sum().backward()
        # meh this sucks

    def eval(self):
        self.model.eval()
        with torch.no_grad():
            logits = self.model(*self.example_inputs)
        return (logits,)
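
train() above only runs the forward and backward passes, and mark_only_lora_as_trainable leaves gradients enabled solely for the LoRA adapter weights. A hedged sketch of layering an optimizer step on top (the optimizer choice and learning rate are illustrative, not taken from this commit):

import torch

m = Model(test="train", device="cuda")
# Only the LoRA adapter parameters require grad after mark_only_lora_as_trainable(model).
trainable_params = [p for p in m.model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=3e-4)

optimizer.zero_grad(set_to_none=True)
m.train()          # forward + backward on the synthetic int32 batch
optimizer.step()   # updates only the LoRA adapter weights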
Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
from torchbenchmark.util.framework.lit_llama import install_lit_llama

if __name__ == '__main__':
    install_lit_llama()
