Commit 5977905

add an example for quantizing LLaMa 4 Scout (#3408)
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
1 parent 73730e8 commit 5977905
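
Based on the `parse_args` definition in the diff below, the new script is a standalone CLI tool: it takes a required output directory plus optional `--max_new_tokens`, `--convert_llama_4_expert_weights_to_mnk`, `--no_save_model_to_disk`, and `--no_load_model_from_disk` flags. A minimal invocation (the output path here is a placeholder) would presumably look like:

    python examples/quantize_llama_4.py /path/to/quantized_model_dir --max_new_tokens 64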

File tree

1 file changed: +221 −0 lines changed


examples/quantize_llama_4.py

Lines changed: 221 additions & 0 deletions
@@ -0,0 +1,221 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

"""
A script demonstrating quantization of the routed experts of
the `meta-llama/Llama-4-Scout-17B-16E-Instruct` model from HuggingFace
to w8a8 with float8 rowwise weights and activations.
"""

import argparse
import random
from pathlib import Path

import fbgemm_gpu
import numpy as np
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig

from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    FqnToConfig,
    PerRow,
)
from torchao.quantization.quantize_.workflows.float8.float8_tensor import (
    Float8Tensor,
)


# Set seeds for reproducibility
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def get_quantization_config():
    expert_3d_weight_single_config = Float8DynamicActivationFloat8WeightConfig(
        # the weights of this model are stored in (B, K, N) layout, and we
        # need to quantize rowwise across the K axis, which is `PerRow(1)`.
        granularity=[PerRow(), PerRow(1)],
        # guard against activations with groups of all-zeroes
        activation_value_lb=1.0e-12,
    )
    fqn_to_config = FqnToConfig(
        {
            # only quantize the routed experts, the rest of the model is left
            # in high precision
            r"re:.*\.feed_forward\.experts\.gate_up_proj": expert_3d_weight_single_config,
            r"re:.*\.feed_forward\.experts\.down_proj": expert_3d_weight_single_config,
        }
    )
    return TorchAoConfig(quant_type=fqn_to_config)


def parse_args():
    parser = argparse.ArgumentParser(description="Quantize a model with TorchAO")
    parser.add_argument(
        "output_dir",
        type=str,
        help="Directory to save the quantized model",
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=64,
        help="Max tokens to generate for testing (default: 64)",
    )
    parser.add_argument(
        "--convert_llama_4_expert_weights_to_mnk",
        action="store_true",
        help="If set, converts LLaMa 4 Scout expert weights from MKN to MNK memory layout",
    )
    parser.add_argument(
        "--no_save_model_to_disk",
        action="store_true",
        help="If set, skips saving quantized model to local disk",
    )
    parser.add_argument(
        "--no_load_model_from_disk",
        action="store_true",
        help="If set, skips reloading model from disk to test it again",
    )
    return parser.parse_args()


def main(args):
    """
    Args:
        args: Parsed command line arguments containing:
            output_dir: Directory to save the quantized model
            max_new_tokens: Max tokens to generate for testing
            convert_llama_4_expert_weights_to_mnk: if True, converts LLaMa 4 Scout expert weights from MKN to MNK memory layout
            no_save_model_to_disk: if True, skips saving quantized model to local disk
            no_load_model_from_disk: if True, skips reloading model from disk to test it again
    """

    # ensure relevant dependency versions are satisfied
    t_v = str(transformers.__version__)
    assert t_v >= "4.58", (
        f"transformers version {t_v} too old, please upgrade to a transformers version with https://github.com/huggingface/transformers/pull/41894"
    )
    f_v = str(fbgemm_gpu.__version__)
    if f_v.startswith("202"):
        # nightly version, such as '2025.11.22+cu128'
        assert f_v >= "2025.11.22", (
            f"fbgemm_gpu nightly version {f_v} too old, please upgrade to a nightly from 2025-11-22 or later"
        )
    else:
        # stable version, such as '1.4.1'
        assert f_v >= "1.5", (
            f"fbgemm_gpu stable version {f_v} too old, please upgrade to 1.5 or later"
        )

    model_name = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
    device_map = "auto"

    # Test prompts
    prompts = [
        "Why is Pytorch 2.0 the best machine learning compiler?",
    ]

    # Set seed before creating the model
    set_seed(42)

    # Create output directory
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Get quantization config
    quantization_config = get_quantization_config()

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Load and quantize model
    print("Loading and quantizing model...")
    quantized_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="bfloat16",
        device_map=device_map,
        quantization_config=quantization_config,
    )
    print(quantized_model)

    # Test generation
    print("\nTesting quantized model generation...")
    input_ids = tokenizer(prompts, return_tensors="pt", padding=True).to(
        quantized_model.device
    )
    outputs = quantized_model.generate(**input_ids, max_new_tokens=args.max_new_tokens)

    for i, (prompt, output) in enumerate(zip(prompts, outputs, strict=False)):
        generated_text = tokenizer.decode(output, skip_special_tokens=True)
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

    save_model_to_disk = not args.no_save_model_to_disk
    load_model_from_disk = not args.no_load_model_from_disk

    if save_model_to_disk:
        # Save quantized model
        print(f"\nSaving quantized model to: {output_dir}")

        if args.convert_llama_4_expert_weights_to_mnk:
            print("\nConverting LLaMa 4 expert weights from MKN to MNK layout")

            # source: https://github.com/huggingface/transformers/blob/6f6095e0cf509f7384d3ce0c1804013ef6cafd5f/src/transformers/modeling_utils.py#L3466
            def save_function(shard, filename):
                # `save_pretrained` default logic calls tensor.contiguous() before
                # saving, so if we do mkn -> mnk before saving it will be
                # converted back to mkn.
                # We undo this in the custom save_function, which runs after
                # the contiguous call in `save_pretrained`. :)
                for k, v in shard.items():
                    # hacky check for LLaMa 4 experts
                    if isinstance(v, Float8Tensor) and len(v.shape) == 3:
                        v.qdata = (
                            v.qdata.transpose(-2, -1).contiguous().transpose(-2, -1)
                        )
                torch.save(shard, filename)

        else:
            save_function = torch.save

        quantized_model.save_pretrained(
            output_dir,
            safe_serialization=False,
            save_function=save_function,
        )
        tokenizer.save_pretrained(output_dir)

    if load_model_from_disk:
        assert save_model_to_disk, "unimplemented"
        # Load saved model to verify
        # TODO: do we really need `weights_only=False` here?
        loaded_model = AutoModelForCausalLM.from_pretrained(
            output_dir,
            device_map=device_map,
            torch_dtype="auto",
            weights_only=False,
        )

        # Test loaded model with first prompt
        test_prompt = prompts[0]
        input_ids = tokenizer(test_prompt, return_tensors="pt").to(loaded_model.device)
        output = loaded_model.generate(**input_ids, max_new_tokens=args.max_new_tokens)
        generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
        print(
            f"Verification - Prompt: {test_prompt!r}, Generated text: {generated_text!r}"
        )

    print("\nQuantization process completed successfully.")


if __name__ == "__main__":
    args = parse_args()
    main(args)
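
The `granularity=[PerRow(), PerRow(1)]` comment in the diff ("quantize rowwise across the K axis") can be made concrete with a small sketch. This is only an illustration of per-row scaling math for a (B, K, N) weight, with toy shapes chosen for demonstration; it is not how torchao's Float8DynamicActivationFloat8WeightConfig is implemented internally.

import torch

# Toy (B, K, N) expert weight, shaped like the stacked gate_up_proj / down_proj tensors.
B, K, N = 2, 8, 4
w = torch.randn(B, K, N, dtype=torch.bfloat16)

# "Rowwise across the K axis": one scale per (b, n) pair, reduced over dim 1 (K).
f8_max = torch.finfo(torch.float8_e4m3fn).max      # 448.0 for e4m3
amax = w.abs().amax(dim=1, keepdim=True).float()   # shape (B, 1, N)
scale = (amax / f8_max).clamp(min=1e-12)           # guard against all-zero rows
w_f8 = (w.float() / scale).clamp(-f8_max, f8_max).to(torch.float8_e4m3fn)
w_dq = w_f8.float() * scale                        # rough dequantized reconstruction

print(w_f8.shape, scale.shape)  # torch.Size([2, 8, 4]) torch.Size([2, 1, 4])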

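The custom save_function in the diff converts the expert weights from MKN to MNK memory layout with a transpose / contiguous / transpose round trip on qdata. A minimal sketch of why that works, using a plain integer tensor rather than a Float8Tensor: the logical shape stays (M, K, N) and the values are untouched, but the storage order of the last two dims flips.

import torch

M, K, N = 2, 3, 4
t = torch.arange(M * K * N).reshape(M, K, N)  # MKN-contiguous: strides (K*N, N, 1)
print(t.shape, t.stride())                    # torch.Size([2, 3, 4]) (12, 4, 1)

# Same trick the save_function applies to v.qdata: swap the last two dims,
# materialize that layout, then swap back. Logical shape is unchanged, but
# memory is now laid out as (M, N, K), i.e. "MNK".
t_mnk = t.transpose(-2, -1).contiguous().transpose(-2, -1)
print(t_mnk.shape, t_mnk.stride())            # torch.Size([2, 3, 4]) (12, 1, 3)
print(torch.equal(t, t_mnk))                  # True

The script performs this inside save_function because, as its comments note, save_pretrained calls tensor.contiguous() on each shard first, which would otherwise undo the layout change.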