Commit 674955d

create a new accuracy eval script for official README.md eval accuracy
Summary:

Creates a standalone eval script for generating the accuracy metrics in the quantization README.md, based on the HuggingFace model definition of Llama 3.1 8B.

Why a new script?
1. The current `prod` script in https://github.com/pytorch/ao/blob/main/torchao/_models/llama/eval.py uses a custom model definition that predates the HF integration; it's better to use HF's model definition now.
2. We have HummingBird scripts in https://github.com/pytorch/ao/tree/40c4f44677ae11166c3dcfbb9189cfa78789390c/.github/scripts/torchao_model_releases, but they seem pretty verbose and hard to use/modify.
3. We have https://github.com/pytorch/ao/blob/main/benchmarks/_models/eval_hf_models.py, which I copy-pasted and modified for the current PR. That script didn't work as-is for various reasons and also seemed hard to use/modify; for the main README.md it's important to have a very simple standalone script.

We should probably do a pass on the naming before landing.

Future work:
1. add metrics for `int4_weight_only_hqq` (need to run on A100)
2. add metrics for `mxfp8` and `nvfp4` (need to run on B200)
3. make the parsing of logs automated
4. also add a similar script for performance benchmarks, using vllm
5. delete https://github.com/pytorch/ao/blob/main/torchao/_models/llama/

Test Plan:

```
// debug run on a small model
with-proxy time ./benchmarks/quantization/eval_accuracy_for_readme.sh facebook/opt-125m

// real run
with-proxy time ./benchmarks/quantization/eval_accuracy_for_readme.sh
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 39c1d72
ghstack-comment-id: 3618394399
Pull-Request: #3449
1 parent 69ce0fd commit 674955d
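
To make point 1 above concrete: the new script evaluates against the HuggingFace model definition by applying torchao configs via `TorchAoConfig` at `from_pretrained` time. A condensed sketch of that pattern (not part of the diff itself; the model id and recipe mirror the defaults used by the new script below):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerRow

# load HF's Llama 3.1 8B definition directly, with a torchao config applied at load time
model_id = "meta-llama/Llama-3.1-8B"
quantization_config = TorchAoConfig(
    quant_type=Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    dtype=torch.bfloat16,
    quantization_config=quantization_config,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
```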

File tree

3 files changed: +248 -24 lines changed
benchmarks/quantization/eval_accuracy_for_readme.py

Lines changed: 176 additions & 0 deletions
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import subprocess

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig

from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    Float8DynamicActivationInt4WeightConfig,
    Int4WeightOnlyConfig,
    Int8DynamicActivationInt8WeightConfig,
    Int8WeightOnlyConfig,
    PerRow,
)


def string_to_config(s):
    if s is None:
        return None
    elif s == "float8_rowwise":
        return Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
    elif s == "int4_weight_float8_rowwise_activation":
        return Float8DynamicActivationInt4WeightConfig()
    elif s == "int4_weight_only_hqq":
        return Int4WeightOnlyConfig(
            group_size=32,
            int4_packing_format="tile_packed_to_4d",
            int4_choose_qparams_algorithm="hqq",
        )
    elif s == "int8_weight_only":
        return Int8WeightOnlyConfig()
    elif s == "int8":
        return Int8DynamicActivationInt8WeightConfig()
    else:
        raise AssertionError(f"unsupported {s}")


def quantize_model_and_save(model_id, quant_config, output_dir="results"):
    """Quantize the model and save it to the output directory."""
    print("Quantizing model with config: ", quant_config)
    if quant_config is None:
        quantization_config = None
    else:
        quantization_config = TorchAoConfig(quant_type=quant_config)
    quantized_model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        dtype=torch.bfloat16,
        quantization_config=quantization_config,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    quantized_model.save_pretrained(output_dir, safe_serialization=False)
    tokenizer.save_pretrained(output_dir, safe_serialization=False)
    return quantized_model, tokenizer


def run_lm_eval(model_dir, tasks_list=["hellaswag"], device="cuda:0", batch_size=8):
    """Run the lm_eval command using subprocess."""
    tasks_str = ",".join(tasks_list)
    command = [
        "lm_eval",
        "--model",
        "hf",
        "--model_args",
        f"pretrained={model_dir}",
        "--tasks",
        f"{tasks_str}",
        "--device",
        f"{device}",
        "--batch_size",
        f"{batch_size}",
        "--output_path",
        f"{model_dir}/lm_eval_outputs/",
    ]
    subprocess.run(command, check=True)


def get_size_of_dir(model_output_dir):
    # get dir size from shell, to skip complexity of dealing with tensor
    # subclasses
    result = subprocess.run(
        ["du", "-sb", model_output_dir], capture_output=True, text=True
    )
    size = int(result.stdout.split()[0])
    return size


def run(
    model_id: str,
    quant_recipe_name: str | None,
    tasks,
    device,
    batch_size,
    model_output_dir,
):
    print(f"\nRunning {model_id=} with {quant_recipe_name=}\n")
    model_name = model_id.split("/")[-1]
    model_output_dir = (
        f"benchmarks/data/quantized_model/{model_name}-{quant_recipe_name}"
    )
    quant_config = string_to_config(quant_recipe_name)
    quantized_model, tokenizer = quantize_model_and_save(
        model_id, quant_config=quant_config, output_dir=model_output_dir
    )
    print(quantized_model)

    model_size = get_size_of_dir(model_output_dir) / 1e9
    print(f"checkpoint size: {model_size} GB")

    run_lm_eval(
        model_output_dir, tasks_list=tasks, device=device, batch_size=batch_size
    )
    print("done\n")


if __name__ == "__main__":
    try:
        import lm_eval  # noqa: F401
    except:
        print(
            "lm_eval is required to run this script. Please install it using pip install lm-eval."
        )
        exit(0)

    # Set up argument parser
    parser = argparse.ArgumentParser(
        description="Quantize a model and evaluate its throughput."
    )
    parser.add_argument(
        "--model_id",
        type=str,
        default="meta-llama/Llama-3.1-8B",
        help="The model ID to use.",
    )
    parser.add_argument(
        "--quant_recipe_name",
        type=str,
        default=None,
        help="The quantization recipe to use.",
    )
    parser.add_argument(
        "--tasks",
        nargs="+",
        type=str,
        default=["wikitext"],
        help="List of lm-eluther tasks to evaluate usage: --tasks task1 task2",
    )
    parser.add_argument(
        "--device", type=str, default="cuda:0", help="Device to run the model on."
    )
    parser.add_argument(
        "--batch_size", type=str, default="auto", help="Batch size for lm_eval."
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="quantized_models",
        help="Output directory for quantized model.",
    )
    args = parser.parse_args()

    # Use parsed arguments
    run(
        model_id=args.model_id,
        quant_recipe_name=args.quant_recipe_name,
        tasks=args.tasks,
        device=args.device,
        batch_size=args.batch_size,
        model_output_dir=args.output_dir,
    )
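
As a follow-up usage note: the checkpoints written by `quantize_model_and_save` can be reloaded through the standard `transformers` API for a quick manual sanity check. A minimal sketch, not part of the diff, assuming the installed transformers/torchao versions can reload the torchao-serialized checkpoint and that the default Llama-3.1-8B `float8_rowwise` run has already completed (the path follows the `model_output_dir` logic above):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# path produced by quantize_model_and_save for the default model + float8_rowwise recipe
ckpt_dir = "benchmarks/data/quantized_model/Llama-3.1-8B-float8_rowwise"

model = AutoModelForCausalLM.from_pretrained(ckpt_dir, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(ckpt_dir)

# generate a few tokens to confirm the quantized checkpoint loads and runs
inputs = tokenizer("The capital of France is", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```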
benchmarks/quantization/eval_accuracy_for_readme.sh

Lines changed: 30 additions & 0 deletions
#!/bin/bash

set -e

# Get model_id as first positional argument (optional)
MODEL_ID="${1:-meta-llama/Llama-3.1-8B}"

# Get log file as second positional argument (optional)
LOG_FILE="${2:-benchmarks/data/eval_accuracy_for_readme_log.txt}"

# Build the base command arguments
BASE_ARGS="--tasks wikitext winogrande"
if [[ -n "$MODEL_ID" ]]; then
  BASE_ARGS="--model_id $MODEL_ID $BASE_ARGS"
fi

# baseline
# note: the -u flag is to prevent python from buffering stdout and stderr
# and make the output log file be in chronological order
time python -u benchmarks/quantization/eval_accuracy_for_readme.py $BASE_ARGS 2>&1 | tee "$LOG_FILE"

# quantized recipes
# note:
# * `int4_weight_float8_rowwise_activation` doesn't work with dtype_map auto: https://gist.github.com/vkuzo/6b128681b628744d445c553cdeac8a85
# * `int4_weight_only_hqq` only works on A100
for quant_recipe in float8_rowwise int4_weight_float8_rowwise_activation int4_weight_only_hqq int8_weight_only int8; do
  time python -u benchmarks/quantization/eval_accuracy_for_readme.py $BASE_ARGS --quant_recipe_name $quant_recipe 2>&1 | tee -a "$LOG_FILE"
done

# TODO(future PR): script to parse the log file instead of manual copy-paste

torchao/quantization/README.md

Lines changed: 42 additions & 24 deletions
@@ -1,36 +1,54 @@
 # Quantization
 Typically quantization algorithms will have different schemes for how the activation and weights are quantized so A16W8 for instance means the activations are quantized to 16 bits whereas the weights are quantized to 8 bits. Trying out different quantization schemes in `torchao` is generally a 1 line change. Note: exact APIs are not stable, we may change them in the future.
 
-## Benchmarks
-Benchmarks and evaluation are gathered using the scripts for [generation](../_models/llama/generate.py) and [eval](../_models/llama/eval.py). Evaluation was done using the lm_eval library for tasks/data on the meta-llama/Meta-Llama-3-8B model.
+## Accuracy benchmarks
+
+All the following benchmarks are for `meta-llama/Llama-3.1-8B` using `lm-eval`, measured on an H100 GPU.
+
+| Technique | wikitext-perplexity | winogrande | checkpoint size (GB) |
+| --------- | ------------------- | ---------- | -------------------- |
+| baseline (bfloat16) | 7.3315 | 0.7380 | 16.1 |
+| float8_rowwise weight, float8_rowwise activation | 7.4197 | 0.7388 | 9.1 |
+| int8_weight_only | 7.3451 | 0.7340 | 9.1 |
+| int8 weight, int8 activation | 7.4535 | 0.7285 | 9.1 |
+
+To reproduce, run the following command:
+
+```bash
+./benchmarks/quantization/eval_accuracy_for_readme.sh
+```
+
+## Performance benchmarks
+
+Benchmarks are gathered using the scripts for [generation](../_models/llama/generate.py).
 
 ### CUDA backend | NVIDIA-A100-80GB GPU
-| Model | Technique | wikitext-perplexity | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) | Model Size (GB) |
-| ----------- | ----------------------- | ------------------- | ------------- | ----------------------- | ---------------- | --------------- |
-| Llama-3-8B | Base (bfloat16) | 7.441 | 95.64 | 1435.54 | 16.43 | 15.01 |
-| | int8dq | 7.581 | 8.61 | 64.75 | 9.24 | 7.52 |
-| | int8wo | 7.447 | 153.03 | 1150.80 | 10.42 | 7.52 |
-| | fp6 | 7.661 | 161.58 | 910.02 | 7.72 | 5.63 |
-| | int4wo-64 | 8.316 | 180.80 | 763.33 | 6.88 | 4.22 |
-| | int4wo-64-GPTQ | 7.921 | 180.80 | 763.33 | 6.88 | 4.22 |
-| | autoquant-int4hqq | 8.110 | 188.41 | 800.58 | 7.14 | 4.25 |
+| Model | Technique | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) |
+| ----------- | ----------------------- | ------------- | ----------------------- | ---------------- |
+| Llama-3-8B | Base (bfloat16) | 95.64 | 1435.54 | 16.43 |
+| | int8dq | 8.61 | 64.75 | 9.24 |
+| | int8wo | 153.03 | 1150.80 | 10.42 |
+| | fp6 | 161.58 | 910.02 | 7.72 |
+| | int4wo-64 | 180.80 | 763.33 | 6.88 |
+| | int4wo-64-GPTQ | 180.80 | 763.33 | 6.88 |
+| | autoquant-int4hqq | 188.41 | 800.58 | 7.14 |
 
 ### CUDA backend | NVIDIA-H100 GPU
-| Model | Technique | wikitext-perplexity | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) | Model Size (GB) |
-| ----------- | ----------------------- | ------------------- | ------------- | ----------------------- | ---------------- | --------------- |
-| Llama-3.1-8B | Base (bfloat16) | 7.54 | 126.90 | 1904.75 | 16.75 | 15.01 |
-| | int8wo | 7.56 | 198.85 | 1495.41 | 11.05 | 7.52 |
-| | int4wo-64 | 8.44 | 241.39 | 1019.14 | 7.08 | 4.22 |
-| | float8wo | 7.60 | 178.46 | 1339.93 | 12.09 | 7.51 |
-| | float8dq (PerTensor) | 7.62 | 116.40 | 873.58 | 11.14 | 7.51 |
-| | float8dq (Per Row) | 7.61 | 154.63 | 1161.47 | 11.14 | 7.51 |
+| Model | Technique | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) |
+| ----------- | ----------------------- | ------------- | ----------------------- | ---------------- |
+| Llama-3.1-8B | Base (bfloat16) | 126.90 | 1904.75 | 16.75 |
+| | int8wo | 198.85 | 1495.41 | 11.05 |
+| | int4wo-64 | 241.39 | 1019.14 | 7.08 |
+| | float8wo | 178.46 | 1339.93 | 12.09 |
+| | float8dq (PerTensor) | 116.40 | 873.58 | 11.14 |
+| | float8dq (Per Row) | 154.63 | 1161.47 | 11.14 |
 
 ### XPU backend | Intel-Max1100
-| Model | Technique | wikitext-perplexity | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) | Model Size (GB) |
-| ----------- | ----------------------- | ------------------- | ------------- | ----------------------- | ---------------- | --------------- |
-| Llama-3-8.1B | Base (bfloat16) | 7.441 | 40.36 | 605.77 | 16.35 | 15.01 |
-| | int8dq | 7.581 | 13.60 | 102.28 | 18.69 | 7.52 |
-| | int8wo | 7.447 | 59.49 | 447.27 | 18.60 | 7.52
+| Model | Technique | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) |
+| ----------- | ----------------------- | ------------- | ----------------------- | ---------------- |
+| Llama-3-8.1B | Base (bfloat16) | 40.36 | 605.77 | 16.35 |
+| | int8dq | 13.60 | 102.28 | 18.69 |
+| | int8wo | 59.49 | 447.27 | 18.60 |
 
 
 Benchmarks and evaluation for model meta-llama/Meta-Llama-3.1-8B are gathered using [generation](../_models/llama/generate.py) and [eval](../_models/llama/eval.py). Evaluation was done using the lm_eval library for tasks/data.
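
For context on the "1 line change" mentioned in the README's opening paragraph, the underlying eager-mode API is `torchao.quantization.quantize_`. A minimal self-contained sketch, not part of this diff, using a toy module instead of Llama and the `Int8WeightOnlyConfig` recipe from the table above:

```python
import torch
from torchao.quantization import Int8WeightOnlyConfig, quantize_

# toy stand-in for a real checkpoint: any module containing nn.Linear layers
model = torch.nn.Sequential(
    torch.nn.Linear(1024, 4096),
    torch.nn.Linear(4096, 1024),
).to(torch.bfloat16).cuda()

quantize_(model, Int8WeightOnlyConfig())  # the "1 line change"

# run a forward pass on the quantized model
x = torch.randn(8, 1024, dtype=torch.bfloat16, device="cuda")
y = model(x)
```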
