From 002ba19597c8fb7f78646c8b03aae17ac796574a Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 5 Dec 2025 11:58:27 -0800 Subject: [PATCH] Update [ghstack-poisoned] --- .../quantization/eval_accuracy_for_readme.py | 176 ++++++++++++++++++ .../quantization/eval_accuracy_for_readme.sh | 30 +++ torchao/quantization/README.md | 66 ++++--- 3 files changed, 248 insertions(+), 24 deletions(-) create mode 100644 benchmarks/quantization/eval_accuracy_for_readme.py create mode 100755 benchmarks/quantization/eval_accuracy_for_readme.sh diff --git a/benchmarks/quantization/eval_accuracy_for_readme.py b/benchmarks/quantization/eval_accuracy_for_readme.py new file mode 100644 index 0000000000..abf38e1bdb --- /dev/null +++ b/benchmarks/quantization/eval_accuracy_for_readme.py @@ -0,0 +1,176 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import subprocess + +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig + +from torchao.quantization import ( + Float8DynamicActivationFloat8WeightConfig, + Float8DynamicActivationInt4WeightConfig, + Int4WeightOnlyConfig, + Int8DynamicActivationInt8WeightConfig, + Int8WeightOnlyConfig, + PerRow, +) + + +def string_to_config(s): + if s is None: + return None + elif s == "float8_rowwise": + return Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()) + elif s == "int4_weight_float8_rowwise_activation": + return Float8DynamicActivationInt4WeightConfig() + elif s == "int4_weight_only_hqq": + return Int4WeightOnlyConfig( + group_size=32, + int4_packing_format="tile_packed_to_4d", + int4_choose_qparams_algorithm="hqq", + ) + elif s == "int8_weight_only": + return Int8WeightOnlyConfig() + elif s == "int8": + return Int8DynamicActivationInt8WeightConfig() + else: + raise AssertionError(f"unsupported {s}") + + +def quantize_model_and_save(model_id, quant_config, output_dir="results"): + """Quantize the model and save it to the output directory.""" + print("Quantizing model with config: ", quant_config) + if quant_config is None: + quantization_config = None + else: + quantization_config = TorchAoConfig(quant_type=quant_config) + quantized_model = AutoModelForCausalLM.from_pretrained( + model_id, + device_map="auto", + dtype=torch.bfloat16, + quantization_config=quantization_config, + ) + tokenizer = AutoTokenizer.from_pretrained(model_id) + quantized_model.save_pretrained(output_dir, safe_serialization=False) + tokenizer.save_pretrained(output_dir, safe_serialization=False) + return quantized_model, tokenizer + + +def run_lm_eval(model_dir, tasks_list=["hellaswag"], device="cuda:0", batch_size=8): + """Run the lm_eval command using subprocess.""" + tasks_str = ",".join(tasks_list) + command = [ + "lm_eval", + "--model", + "hf", + "--model_args", + f"pretrained={model_dir}", + "--tasks", + f"{tasks_str}", + "--device", + f"{device}", + "--batch_size", + f"{batch_size}", + "--output_path", + f"{model_dir}/lm_eval_outputs/", + ] + subprocess.run(command, check=True) + + +def get_size_of_dir(model_output_dir): + # get dir size from shell, to skip complexity of dealing with tensor + # subclasses + result = subprocess.run( + ["du", "-sb", model_output_dir], capture_output=True, text=True + ) + size = int(result.stdout.split()[0]) + return size + + +def run( + model_id: str, + quant_recipe_name: str | None, + tasks, + device, + batch_size, + model_output_dir, +): + print(f"\nRunning {model_id=} with {quant_recipe_name=}\n") + model_name = model_id.split("/")[-1] + model_output_dir = ( + f"benchmarks/data/quantized_model/{model_name}-{quant_recipe_name}" + ) + quant_config = string_to_config(quant_recipe_name) + quantized_model, tokenizer = quantize_model_and_save( + model_id, quant_config=quant_config, output_dir=model_output_dir + ) + print(quantized_model) + + model_size = get_size_of_dir(model_output_dir) / 1e9 + print(f"checkpoint size: {model_size} GB") + + run_lm_eval( + model_output_dir, tasks_list=tasks, device=device, batch_size=batch_size + ) + print("done\n") + + +if __name__ == "__main__": + try: + import lm_eval # noqa: F401 + except: + print( + "lm_eval is required to run this script. Please install it using pip install lm-eval." + ) + exit(0) + + # Set up argument parser + parser = argparse.ArgumentParser( + description="Quantize a model and evaluate its throughput." + ) + parser.add_argument( + "--model_id", + type=str, + default="meta-llama/Llama-3.1-8B", + help="The model ID to use.", + ) + parser.add_argument( + "--quant_recipe_name", + type=str, + default=None, + help="The quantization recipe to use.", + ) + parser.add_argument( + "--tasks", + nargs="+", + type=str, + default=["wikitext"], + help="List of lm-eluther tasks to evaluate usage: --tasks task1 task2", + ) + parser.add_argument( + "--device", type=str, default="cuda:0", help="Device to run the model on." + ) + parser.add_argument( + "--batch_size", type=str, default="auto", help="Batch size for lm_eval." + ) + parser.add_argument( + "--output_dir", + type=str, + default="quantized_models", + help="Output directory for quantized model.", + ) + args = parser.parse_args() + + # Use parsed arguments + run( + model_id=args.model_id, + quant_recipe_name=args.quant_recipe_name, + tasks=args.tasks, + device=args.device, + batch_size=args.batch_size, + model_output_dir=args.output_dir, + ) diff --git a/benchmarks/quantization/eval_accuracy_for_readme.sh b/benchmarks/quantization/eval_accuracy_for_readme.sh new file mode 100755 index 0000000000..822371f3ca --- /dev/null +++ b/benchmarks/quantization/eval_accuracy_for_readme.sh @@ -0,0 +1,30 @@ +#!/bin/bash + +set -e + +# Get model_id as positional argument (optional) +MODEL_ID="${1:-meta-llama/Llama-3.1-8B}" + +# Get log file as first positional argument (optional) +LOG_FILE="${2:-benchmarks/data/eval_accuracy_for_readme_log.txt}" + +# Build the base command arguments +BASE_ARGS="--tasks wikitext winogrande" +if [[ -n "$MODEL_ID" ]]; then + BASE_ARGS="--model_id $MODEL_ID $BASE_ARGS" +fi + +# baseline +# note: the -u flag is to prevent python from buffering stdout and stderr +# and make the output log file be in chronological order +time python -u benchmarks/quantization/eval_accuracy_for_readme.py $BASE_ARGS 2>&1 | tee "$LOG_FILE" + +# quantized recipes +# note: +# * `int4_weight_float8_rowwise_activation` doesn't work with dtype_map auto: https://gist.github.com/vkuzo/6b128681b628744d445c553cdeac8a85 +# * `int4_weight_only_hqq` only works on A100 +for quant_recipe in float8_rowwise int4_weight_float8_rowwise_activation int4_weight_only_hqq int8_weight_only int8; do + time python -u benchmarks/quantization/eval_accuracy_for_readme.py $BASE_ARGS --quant_recipe_name $quant_recipe 2>&1 | tee -a "$LOG_FILE" +done + +# TODO(future PR): script to parse the log file instead of manual copy-paste diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md index a4b5d2801e..ec09bd8fa6 100644 --- a/torchao/quantization/README.md +++ b/torchao/quantization/README.md @@ -1,36 +1,54 @@ # Quantization Typically quantization algorithms will have different schemes for how the activation and weights are quantized so A16W8 for instance means the activations are quantized to 16 bits wheras the weights are quantized to 8 bits. Trying out different quantization schemes in `torchao` is generally a 1 line change. Note: exact APIs are not stable, we may change them in the future. -## Benchmarks -Benchmarks and evaluation are gathered using the scripts for [generation](../_models/llama/generate.py) and [eval](../_models/llama/eval.py). Evaluation was done using the lm_eval library for tasks/data on the meta-llama/Meta-Llama-3-8B model. +## Accuracy benchmarks + +All the following benchmarks are for `meta-llama/Llama-3-8.1B` using `lm-eval` measured on an H100 GPU. + +| Technique | wikitext-perplexity | winogrande | checkpoint size (GB) | +| --------- | ------------------- | ---------- | -------------------- | +| baseline (bfloat16) | 7.3315 | 0.7380 | 16.1 | +| float8_rowwise weight, float8_rowwise activation | 7.4197 | 0.7388 | 9.1 | +| int8_weight_only | 7.3451 | 0.7340 | 9.1 | +| int8 weight, int8 activation | 7.4535 | 0.7285 | 9.1 | + +To reproduce, run the following command: + +```bash +./benchmarks/quantization/eval_accuracy_for_readme.sh +``` + +## Performance benchmarks + +Benchmarks are gathered using the scripts for [generation](../_models/llama/generate.py). ### CUDA backend | NVIDIA-A100-80GB GPU -| Model | Technique | wikitext-perplexity | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) | Model Size (GB) | -| ----------- | ----------------------- | ------------------- | ------------- | ----------------------- | ---------------- | --------------- | -| Llama-3-8B | Base (bfloat16) | 7.441 | 95.64 | 1435.54 | 16.43 | 15.01 | -| | int8dq | 7.581 | 8.61 | 64.75 | 9.24 | 7.52 | -| | int8wo | 7.447 | 153.03 | 1150.80 | 10.42 | 7.52 | -| | fp6 | 7.661 | 161.58 | 910.02 | 7.72 | 5.63 | -| | int4wo-64 | 8.316 | 180.80 | 763.33 | 6.88 | 4.22 | -| | int4wo-64-GPTQ | 7.921 | 180.80 | 763.33 | 6.88 | 4.22 | -| | autoquant-int4hqq | 8.110 | 188.41 | 800.58 | 7.14 | 4.25 | +| Model | Technique | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) | +| ----------- | ----------------------- | ------------- | ----------------------- | ---------------- | +| Llama-3-8B | Base (bfloat16) | 95.64 | 1435.54 | 16.43 | +| | int8dq | 8.61 | 64.75 | 9.24 | +| | int8wo | 153.03 | 1150.80 | 10.42 | +| | fp6 | 161.58 | 910.02 | 7.72 | +| | int4wo-64 | 180.80 | 763.33 | 6.88 | +| | int4wo-64-GPTQ | 180.80 | 763.33 | 6.88 | +| | autoquant-int4hqq | 188.41 | 800.58 | 7.14 | ### CUDA backend | NVIDIA-H100 GPU -| Model | Technique | wikitext-perplexity | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) | Model Size (GB) | -| ----------- | ----------------------- | ------------------- | ------------- | ----------------------- | ---------------- | --------------- | -| Llama-3.1-8B | Base (bfloat16) | 7.54 | 126.90 | 1904.75 | 16.75 | 15.01 | -| | int8wo | 7.56 | 198.85 | 1495.41 | 11.05 | 7.52 | -| | int4wo-64 | 8.44 | 241.39 | 1019.14 | 7.08 | 4.22 | -| | float8wo | 7.60 | 178.46 | 1339.93 | 12.09 | 7.51 | -| | float8dq (PerTensor) | 7.62 | 116.40 | 873.58 | 11.14 | 7.51 | -| | float8dq (Per Row) | 7.61 | 154.63 | 1161.47 | 11.14 | 7.51 | +| Model | Technique | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) | +| ----------- | ----------------------- | ------------- | ----------------------- | ---------------- | +| Llama-3.1-8B | Base (bfloat16) | 126.90 | 1904.75 | 16.75 | +| | int8wo | 198.85 | 1495.41 | 11.05 | +| | int4wo-64 | 241.39 | 1019.14 | 7.08 | +| | float8wo | 178.46 | 1339.93 | 12.09 | +| | float8dq (PerTensor) | 116.40 | 873.58 | 11.14 | +| | float8dq (Per Row) | 154.63 | 1161.47 | 11.14 | ### XPU backend | Intel-Max1100 -| Model | Technique | wikitext-perplexity | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) | Model Size (GB) | -| ----------- | ----------------------- | ------------------- | ------------- | ----------------------- | ---------------- | --------------- | -| Llama-3-8.1B | Base (bfloat16) | 7.441 | 40.36 | 605.77 | 16.35 | 15.01 | -| | int8dq | 7.581 | 13.60 | 102.28 | 18.69 | 7.52 | -| | int8wo | 7.447 | 59.49 | 447.27 | 18.60 | 7.52 +| Model | Technique | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) | +| ----------- | ----------------------- | ------------- | ----------------------- | ---------------- | +| Llama-3-8.1B | Base (bfloat16) | 40.36 | 605.77 | 16.35 | +| | int8dq | 13.60 | 102.28 | 18.69 | +| | int8wo | 59.49 | 447.27 | 18.60 | Benchmarks and evaluation for model meta-llama/Meta-Llama-3.1-8B are gathered using [generation](../_models/llama/generate.py) and [eval](../_models/llama/eval.py). Evaluation was done using the lm_eval library for tasks/data.