diff --git a/benchmarks/_models/eval_hf_models.sh b/benchmarks/_models/eval_hf_models.sh index d71d16e422..472dbac02a 100644 --- a/benchmarks/_models/eval_hf_models.sh +++ b/benchmarks/_models/eval_hf_models.sh @@ -4,26 +4,28 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. +# This script uses the unified llm_eval.py for LLM evaluation +# For full options, run: python -m benchmarks._models.llm_eval --help # For llama3.1-8B -python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --tasks wikitext hellaswag -python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization float8dq-row --tasks wikitext hellaswag -python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization float8dq-tensor --tasks wikitext hellaswag -python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization float8wo --tasks wikitext hellaswag -python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization int4wo-128 --tasks wikitext hellaswag -python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization int8wo --tasks wikitext hellaswag -python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization int8dq --tasks wikitext hellaswag -python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization gemlitewo-128-4 --tasks wikitext hellaswag -python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization gemlitewo-128-8 --tasks wikitext hellaswag +python -m benchmarks._models.llm_eval --model_id meta-llama/Llama-3.1-8B --tasks wikitext hellaswag +python -m benchmarks._models.llm_eval --model_id meta-llama/Llama-3.1-8B --quantization float8dq-row --tasks wikitext hellaswag +python -m benchmarks._models.llm_eval --model_id meta-llama/Llama-3.1-8B --quantization float8dq-tensor --tasks wikitext hellaswag +python -m benchmarks._models.llm_eval --model_id meta-llama/Llama-3.1-8B --quantization float8wo --tasks wikitext hellaswag +python -m benchmarks._models.llm_eval --model_id meta-llama/Llama-3.1-8B --quantization int4wo-128 --tasks wikitext hellaswag +python -m benchmarks._models.llm_eval --model_id meta-llama/Llama-3.1-8B --quantization int8wo --tasks wikitext hellaswag +python -m benchmarks._models.llm_eval --model_id meta-llama/Llama-3.1-8B --quantization int8dq --tasks wikitext hellaswag +python -m benchmarks._models.llm_eval --model_id meta-llama/Llama-3.1-8B --quantization gemlitewo-4-128 --tasks wikitext hellaswag +python -m benchmarks._models.llm_eval --model_id meta-llama/Llama-3.1-8B --quantization gemlitewo-8 --tasks wikitext hellaswag # For llama3.2-3B -python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --tasks wikitext hellaswag -python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization float8dq-row --tasks wikitext hellaswag -python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization float8dq-tensor --tasks wikitext hellaswag -python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization float8wo --tasks wikitext hellaswag -python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization int4wo-128 --tasks wikitext hellaswag -python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B 
--quantization int8wo --tasks wikitext hellaswag -python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization int8dq --tasks wikitext hellaswag -python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization gemlitewo-128-4 --tasks wikitext hellaswag -python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization gemlitewo-128-8 --tasks wikitext hellaswag +python -m benchmarks._models.llm_eval --model_id meta-llama/Llama-3.2-3B --tasks wikitext hellaswag +python -m benchmarks._models.llm_eval --model_id meta-llama/Llama-3.2-3B --quantization float8dq-row --tasks wikitext hellaswag +python -m benchmarks._models.llm_eval --model_id meta-llama/Llama-3.2-3B --quantization float8dq-tensor --tasks wikitext hellaswag +python -m benchmarks._models.llm_eval --model_id meta-llama/Llama-3.2-3B --quantization float8wo --tasks wikitext hellaswag +python -m benchmarks._models.llm_eval --model_id meta-llama/Llama-3.2-3B --quantization int4wo-128 --tasks wikitext hellaswag +python -m benchmarks._models.llm_eval --model_id meta-llama/Llama-3.2-3B --quantization int8wo --tasks wikitext hellaswag +python -m benchmarks._models.llm_eval --model_id meta-llama/Llama-3.2-3B --quantization int8dq --tasks wikitext hellaswag +python -m benchmarks._models.llm_eval --model_id meta-llama/Llama-3.2-3B --quantization gemlitewo-4-128 --tasks wikitext hellaswag +python -m benchmarks._models.llm_eval --model_id meta-llama/Llama-3.2-3B --quantization gemlitewo-8 --tasks wikitext hellaswag diff --git a/benchmarks/_models/llm_eval.py b/benchmarks/_models/llm_eval.py new file mode 100644 index 0000000000..f89d75d489 --- /dev/null +++ b/benchmarks/_models/llm_eval.py @@ -0,0 +1,608 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +""" +Unified LLM Evaluation Script + +This script provides a unified interface for evaluating language models with various +quantization methods. 
It supports both: +- gpt-fast format (.pth checkpoints) +- HuggingFace format (model IDs or local paths) + +Usage: + # gpt-fast checkpoint + python -m benchmarks._models.llm_eval \ + --checkpoint_path /path/to/model.pth \ + --quantization int4wo-128 \ + --tasks wikitext + + # HuggingFace model + python -m benchmarks._models.llm_eval \ + --model_id meta-llama/Llama-3.1-8B \ + --quantization int8wo \ + --tasks wikitext hellaswag + + # Auto-detect format + python -m benchmarks._models.llm_eval \ + --model meta-llama/Llama-3.1-8B \ + --quantization int4wo-128 \ + --tasks wikitext +""" + +import argparse +import itertools +import time +from pathlib import Path +from typing import List, Optional, Union + +import torch + +import torchao +from benchmarks.microbenchmarks.utils import apply_quantization + + +# ============================================================================= +# Model Loading +# ============================================================================= + + +def _is_hf_model_id(model_path: str) -> bool: + """Check if the model path is a HuggingFace model ID.""" + # HF model IDs typically have format: org/model-name or just model-name + # .pth files are gpt-fast format + if model_path.endswith(".pth"): + return False + path = Path(model_path) + # If it's a directory with model files, treat as local HF model + if path.is_dir(): + return (path / "config.json").exists() or (path / "model.safetensors").exists() + # If it contains a slash and doesn't exist as a file, assume HF model ID + if "/" in model_path and not path.exists(): + return True + # If it's a .pth file path (even if doesn't exist yet), it's gpt-fast + return not str(model_path).endswith(".pth") + + +def load_model_gptfast( + checkpoint_path: Path, + device: str = "cpu", + precision: torch.dtype = torch.bfloat16, +): + """Load a gpt-fast format model (.pth checkpoint). + + Args: + checkpoint_path: Path to the .pth checkpoint file + device: Device to load model to + precision: Model precision (dtype) + + Returns: + Tuple of (model, tokenizer, input_prep_func) + """ + # Import from the llama module - these are relative imports when running from the repo + import sys + + # Add the llama directory to path for relative imports + llama_dir = Path(__file__).parent.parent.parent / "torchao" / "_models" / "llama" + if str(llama_dir) not in sys.path: + sys.path.insert(0, str(llama_dir)) + + from torchao._models.llama.generate import _load_model, device_sync + from torchao._models.llama.model import prepare_inputs_for_model + + # Also need to get the tokenizer + tokenizer_path = checkpoint_path.parent / "tokenizer.model" + assert checkpoint_path.is_file(), f"Checkpoint not found: {checkpoint_path}" + assert tokenizer_path.is_file(), f"Tokenizer not found: {tokenizer_path}" + + print(f"Loading gpt-fast model from {checkpoint_path}...") + t0 = time.time() + model = _load_model(checkpoint_path, device, precision) + device_sync(device=device) + print(f"Time to load model: {time.time() - t0:.02f} seconds") + + # Load tokenizer + from torchao._models.llama.tokenizer import get_tokenizer + + tokenizer = get_tokenizer(tokenizer_path, checkpoint_path) + + return model, tokenizer, prepare_inputs_for_model + + +def load_model_huggingface( + model_id: str, + device: str = "cuda", + precision: torch.dtype = torch.bfloat16, +): + """Load a HuggingFace format model. 
+ + Args: + model_id: HuggingFace model ID or local path + device: Device to load model to + precision: Model precision (dtype) + + Returns: + Tuple of (model, tokenizer, input_prep_func) + """ + from transformers import AutoModelForCausalLM, AutoTokenizer + + print(f"Loading HuggingFace model: {model_id}...") + t0 = time.time() + + model = AutoModelForCausalLM.from_pretrained( + model_id, + torch_dtype=precision, + device_map=device if device != "cpu" else None, + ) + if device == "cpu": + model = model.to(device) + + tokenizer = AutoTokenizer.from_pretrained(model_id) + print(f"Time to load model: {time.time() - t0:.02f} seconds") + + # HuggingFace models don't need input preparation + input_prep_func = None + + return model, tokenizer, input_prep_func + + +def load_model( + model: str, + checkpoint_path: Optional[Path] = None, + model_id: Optional[str] = None, + device: str = "cuda", + precision: torch.dtype = torch.bfloat16, +): + """Load a model, auto-detecting the format. + + Args: + model: Model path or ID (used for auto-detection) + checkpoint_path: Explicit gpt-fast checkpoint path + model_id: Explicit HuggingFace model ID + device: Device to load model to + precision: Model precision + + Returns: + Tuple of (model, tokenizer, input_prep_func, model_format) + """ + # Explicit checkpoint path takes precedence + if checkpoint_path is not None: + model, tokenizer, input_prep_func = load_model_gptfast( + checkpoint_path, device="cpu", precision=precision + ) + return model, tokenizer, input_prep_func, "gptfast" + + # Explicit model_id takes precedence + if model_id is not None: + model, tokenizer, input_prep_func = load_model_huggingface( + model_id, device=device, precision=precision + ) + return model, tokenizer, input_prep_func, "huggingface" + + # Auto-detect from model argument + if model is not None: + if _is_hf_model_id(model): + model_obj, tokenizer, input_prep_func = load_model_huggingface( + model, device=device, precision=precision + ) + return model_obj, tokenizer, input_prep_func, "huggingface" + else: + model_obj, tokenizer, input_prep_func = load_model_gptfast( + Path(model), device="cpu", precision=precision + ) + return model_obj, tokenizer, input_prep_func, "gptfast" + + raise ValueError("Must provide either --model, --checkpoint_path, or --model_id") + + +# ============================================================================= +# Model Size Calculation +# ============================================================================= + + +def get_model_size_in_bytes(model: torch.nn.Module, ignore_embeddings: bool = False) -> int: + """Calculate model size in bytes, handling quantized tensors. 
+ + Args: + model: The model to measure + ignore_embeddings: Whether to ignore embedding layers + + Returns: + Model size in bytes + """ + + def flat_size(tensor): + if hasattr(tensor, "__tensor_flatten__"): + size = 0 + for attr_name in tensor.__tensor_flatten__()[0]: + sub_tensor = getattr(tensor, attr_name) + size += flat_size(sub_tensor) + return size + else: + return tensor.numel() * tensor.element_size() + + model_size = 0 + for _, child in model.named_children(): + if not (isinstance(child, torch.nn.Embedding) and ignore_embeddings): + for p in itertools.chain( + child.parameters(recurse=False), child.buffers(recurse=False) + ): + model_size += flat_size(p) + model_size += get_model_size_in_bytes(child, ignore_embeddings) + return model_size + + +# ============================================================================= +# Evaluation +# ============================================================================= + + +def run_evaluation( + # Model specification (one of these required) + model: Optional[str] = None, + checkpoint_path: Optional[Path] = None, + model_id: Optional[str] = None, + # Evaluation parameters + tasks: List[str] = None, + limit: Optional[int] = None, + # Model configuration + device: str = "cuda", + precision: torch.dtype = torch.bfloat16, + # Quantization + quantization: Optional[str] = None, + sparsity: Optional[str] = None, + # Compilation + compile: bool = False, + compile_mode: str = "max-autotune", + # Sequence length + max_length: Optional[int] = None, + # Calibration (for GPTQ, AWQ, etc.) + calibration_tasks: Optional[List[str]] = None, + calibration_limit: Optional[int] = None, + calibration_seq_length: Optional[int] = None, + pad_calibration_inputs: bool = False, + # Output + print_model: bool = False, + output_dir: Optional[str] = None, +): + """Run LLM evaluation with optional quantization. + + This is the main entry point that supports both gpt-fast and HuggingFace models. + + Args: + model: Model path or HuggingFace model ID (auto-detected) + checkpoint_path: Explicit gpt-fast checkpoint path (.pth) + model_id: Explicit HuggingFace model ID + tasks: List of lm-eval tasks to run + limit: Number of samples to evaluate (None = all) + device: Device to run on + precision: Model precision (torch.bfloat16, torch.float16, etc.) + quantization: Quantization method string + sparsity: Sparsity method string + compile: Whether to torch.compile the model + compile_mode: Compilation mode (max-autotune, reduce-overhead, etc.) + max_length: Maximum sequence length for evaluation + calibration_tasks: Tasks for calibration (GPTQ, AWQ, etc.) 
+ calibration_limit: Number of calibration samples + calibration_seq_length: Sequence length for calibration + pad_calibration_inputs: Whether to pad short calibration sequences + print_model: Whether to print the model architecture + output_dir: Directory to save quantized model (HF models only) + + Returns: + Evaluation results dictionary + """ + if tasks is None: + tasks = ["wikitext"] + + print( + f"\n{'='*60}\n" + f"LLM Evaluation\n" + f"{'='*60}\n" + f"Model: {model or checkpoint_path or model_id}\n" + f"Tasks: {tasks}\n" + f"Limit: {limit}\n" + f"Device: {device}\n" + f"Precision: {precision}\n" + f"Quantization: {quantization}\n" + f"Sparsity: {sparsity}\n" + f"Compile: {compile}\n" + f"{'='*60}\n" + ) + + # Set recommended inductor config + torchao.quantization.utils.recommended_inductor_config_setter() + + # Load model + model_obj, tokenizer, input_prep_func, model_format = load_model( + model=model, + checkpoint_path=checkpoint_path, + model_id=model_id, + device=device, + precision=precision, + ) + + # Set max_length from model config if not provided + if max_length is None: + if hasattr(model_obj, "config"): + if hasattr(model_obj.config, "block_size"): + max_length = model_obj.config.block_size + elif hasattr(model_obj.config, "max_position_embeddings"): + max_length = model_obj.config.max_position_embeddings + else: + max_length = 2048 + else: + max_length = 2048 + + # Apply quantization using the unified apply_quantization function + if quantization or sparsity: + print(f"Applying quantization: {quantization}, sparsity: {sparsity}") + model_obj = apply_quantization( + model=model_obj, + quantization=quantization, + sparsity=sparsity, + tokenizer=tokenizer, + calibration_tasks=calibration_tasks or ["wikitext"], + calibration_limit=calibration_limit, + calibration_seq_length=calibration_seq_length, + pad_calibration_inputs=pad_calibration_inputs, + device=device, + input_prep_func=input_prep_func, + max_seq_length=max_length, + ) + + # Compile model if requested + if compile: + print(f"Compiling model with mode: {compile_mode}") + if quantization == "float8_a1x128_w128x128": + model_obj = torch.compile(model_obj) + else: + model_obj = torch.compile(model_obj, mode=compile_mode, fullgraph=True) + + # Print model if requested + if print_model: + print(model_obj) + + # Calculate and print model size + model_size = get_model_size_in_bytes(model_obj, ignore_embeddings=True) / 1e9 + print(f"Model size: {model_size:.2f} GB") + + # Move model to device + model_obj = model_obj.to(device) + + # Run evaluation + print("\nRunning evaluation...") + with torch.no_grad(): + from torchao._models._eval import TransformerEvalWrapper + + wrapper = TransformerEvalWrapper( + model=model_obj, + tokenizer=tokenizer, + max_seq_length=max_length, + input_prep_func=input_prep_func, + device=device, + ) + result = wrapper.run_eval(tasks=tasks, limit=limit) + + return result + + +# ============================================================================= +# CLI +# ============================================================================= + + +def main(): + """Main entry point for CLI.""" + try: + import lm_eval # noqa: F401 + except ImportError: + print( + "lm_eval is required to run this script. 
" + "Please install it using: pip install lm-eval" + ) + return + + parser = argparse.ArgumentParser( + description="Unified LLM Evaluation Script", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Evaluate gpt-fast checkpoint with int4 quantization + python -m benchmarks._models.llm_eval \\ + --checkpoint_path checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth \\ + --quantization int4wo-128 \\ + --tasks wikitext + + # Evaluate HuggingFace model + python -m benchmarks._models.llm_eval \\ + --model_id meta-llama/Llama-3.1-8B \\ + --quantization int8wo \\ + --tasks wikitext hellaswag + + # GPTQ quantization with calibration + python -m benchmarks._models.llm_eval \\ + --checkpoint_path checkpoints/model.pth \\ + --quantization int4wo-128-gptq \\ + --calibration_tasks wikitext \\ + --calibration_limit 128 \\ + --calibration_seq_length 2048 + +Supported quantization methods: + - int8wo, int8dq: INT8 weight-only / dynamic quantization + - int4wo-: INT4 weight-only (e.g., int4wo-128) + - int4wo--hqq: INT4 with HQQ + - int4wo--gptq: INT4 with GPTQ calibration + - float8wo, float8dq-tensor, float8dq-row: FP8 quantization + - uintx--: UIntX quantization (e.g., uintx-4-64) + - marlin: Marlin sparse layout + - spinquant: SpinQuant preprocessing (combinable, e.g., spinquant-int4wo-128) + - autoround: AutoRound quantization + - awq-uintx--: AWQ quantization + - codebook: Codebook quantization + """, + ) + + # Model specification (mutually exclusive group) + model_group = parser.add_argument_group("Model Specification") + model_group.add_argument( + "--model", + type=str, + default=None, + help="Model path or HuggingFace model ID (auto-detected)", + ) + model_group.add_argument( + "--checkpoint_path", + type=Path, + default=None, + help="Explicit gpt-fast checkpoint path (.pth file)", + ) + model_group.add_argument( + "--model_id", + type=str, + default=None, + help="Explicit HuggingFace model ID", + ) + + # Evaluation parameters + eval_group = parser.add_argument_group("Evaluation") + eval_group.add_argument( + "--tasks", + nargs="+", + type=str, + default=["wikitext"], + help="List of lm-eval tasks (default: wikitext)", + ) + eval_group.add_argument( + "--limit", + type=int, + default=None, + help="Number of samples to evaluate (default: all)", + ) + eval_group.add_argument( + "--max_length", + type=int, + default=None, + help="Maximum sequence length for evaluation", + ) + + # Device and precision + device_group = parser.add_argument_group("Device & Precision") + device_group.add_argument( + "--device", + type=str, + default="cuda", + help="Device to run on (default: cuda)", + ) + device_group.add_argument( + "--precision", + type=lambda x: getattr(torch, x.split(".")[-1]), + default=torch.bfloat16, + help="Model precision (default: bfloat16)", + ) + + # Quantization + quant_group = parser.add_argument_group("Quantization") + quant_group.add_argument( + "-q", + "--quantization", + type=str, + default=None, + help="Quantization method (see examples below)", + ) + quant_group.add_argument( + "--sparsity", + type=str, + default=None, + help="Sparsity method (semi, 2:4, block)", + ) + + # Compilation + compile_group = parser.add_argument_group("Compilation") + compile_group.add_argument( + "--compile", + action="store_true", + help="Enable torch.compile", + ) + compile_group.add_argument( + "--compile_mode", + type=str, + default="max-autotune", + choices=["default", "reduce-overhead", "max-autotune"], + help="torch.compile mode (default: max-autotune)", + ) + + # 
Calibration (for GPTQ, AWQ, etc.) + calib_group = parser.add_argument_group("Calibration (for GPTQ, AWQ, AutoRound)") + calib_group.add_argument( + "--calibration_tasks", + nargs="+", + type=str, + default=["wikitext"], + help="Tasks for calibration data (default: wikitext)", + ) + calib_group.add_argument( + "--calibration_limit", + type=int, + default=128, + help="Number of calibration samples (default: 128)", + ) + calib_group.add_argument( + "--calibration_seq_length", + type=int, + default=2048, + help="Sequence length for calibration (default: 2048)", + ) + calib_group.add_argument( + "--pad_calibration_inputs", + action="store_true", + help="Pad short calibration sequences", + ) + + # Output + output_group = parser.add_argument_group("Output") + output_group.add_argument( + "--print_model", + action="store_true", + help="Print model architecture", + ) + output_group.add_argument( + "--output_dir", + type=str, + default=None, + help="Directory to save quantized model", + ) + + args = parser.parse_args() + + # Validate that at least one model source is provided + if args.model is None and args.checkpoint_path is None and args.model_id is None: + parser.error("Must provide one of: --model, --checkpoint_path, or --model_id") + + run_evaluation( + model=args.model, + checkpoint_path=args.checkpoint_path, + model_id=args.model_id, + tasks=args.tasks, + limit=args.limit, + device=args.device, + precision=args.precision, + quantization=args.quantization, + sparsity=args.sparsity, + compile=args.compile, + compile_mode=args.compile_mode, + max_length=args.max_length, + calibration_tasks=args.calibration_tasks, + calibration_limit=args.calibration_limit, + calibration_seq_length=args.calibration_seq_length, + pad_calibration_inputs=args.pad_calibration_inputs, + print_model=args.print_model, + output_dir=args.output_dir, + ) + + +if __name__ == "__main__": + main() + diff --git a/benchmarks/microbenchmarks/utils.py b/benchmarks/microbenchmarks/utils.py index 2c6a443a86..54e00d2b24 100644 --- a/benchmarks/microbenchmarks/utils.py +++ b/benchmarks/microbenchmarks/utils.py @@ -281,6 +281,14 @@ def string_to_config( ) elif "float8wo" in quantization: return Float8WeightOnlyConfig() + elif quantization == "float8_a1x128_w128x128": + # Blockwise float8 quantization with 1x128 activation and 128x128 weight blocks + from torchao.quantization import PerBlock + + return Float8DynamicActivationFloat8WeightConfig( + granularity=(PerBlock([1, 128]), PerBlock([128, 128])), + activation_value_lb=1e-12, + ) elif "float8dq" in quantization: if sparsity and "semi" in sparsity: return Float8DynamicActivationFloat8SemiSparseWeightConfig() @@ -309,9 +317,425 @@ def string_to_config( 256, ], f"int4wo group_size needs to be one of [32,64,128,256] but got {group_size}" return GemliteUIntXWeightOnlyConfig(group_size=group_size, bit_width=bit_width) + if "codebook" in quantization: + # Codebook quantization (prototype) + # Format: codebook or codebook- + from torchao.prototype.quantization.codebook import codebook_weight_only + + params = quantization.split("-") + scale_block_size = int(params[1]) if len(params) > 1 else 64 + return codebook_weight_only(dtype=torch.uint4, scale_block_size=scale_block_size) return None +def _requires_calibration(quantization: Optional[str]) -> bool: + """Check if the quantization method requires calibration data.""" + if quantization is None: + return False + calibration_methods = ["gptq", "autoround", "awq-uintx", "smoothquant"] + return any(method in quantization.lower() for method 
in calibration_methods)
+
+
+def _apply_spinquant(model: torch.nn.Module) -> torch.nn.Module:
+    """Apply SpinQuant rotation transform to the model.
+
+    SpinQuant applies a rotation transform before quantization to reduce
+    quantization error. This is a preprocessing step, not quantization itself.
+
+    Args:
+        model: The model to apply SpinQuant to
+
+    Returns:
+        The model with SpinQuant applied
+    """
+    from torchao.prototype.spinquant import apply_spinquant
+
+    apply_spinquant(model)
+    return model
+
+
+def _apply_gptq(
+    model: torch.nn.Module,
+    tokenizer,
+    quantization: str,
+    calibration_tasks: List[str],
+    calibration_limit: int,
+    calibration_seq_length: int,
+    pad_calibration_inputs: bool,
+    device: str,
+    input_prep_func=None,
+) -> torch.nn.Module:
+    """Apply GPTQ quantization with calibration.
+
+    Format: int4wo-<group_size>-gptq
+
+    Args:
+        model: The model to quantize
+        tokenizer: Tokenizer for encoding calibration data
+        quantization: Quantization string (e.g., "int4wo-128-gptq")
+        calibration_tasks: Tasks to use for calibration (e.g., ["wikitext"])
+        calibration_limit: Number of calibration samples
+        calibration_seq_length: Sequence length for calibration
+        pad_calibration_inputs: Whether to pad short sequences
+        device: Device to run calibration on
+        input_prep_func: Optional function to prepare inputs for the model
+
+    Returns:
+        The quantized model
+    """
+    from torchao._models._eval import LMEvalInputRecorder
+    from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer
+
+    # Parse group size from quantization string: int4wo-<group_size>-gptq
+    parts = quantization.split("-")
+    groupsize = int(parts[1])
+    assert groupsize in [32, 64, 128, 256], (
+        f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
+    )
+
+    # Default input prep function if not provided
+    if input_prep_func is None:
+        input_prep_func = lambda x: (x,)
+
+    # Get vocab size from model config
+    vocab_size = getattr(model.config, "vocab_size", 32000)
+
+    # Record calibration inputs
+    inputs = (
+        LMEvalInputRecorder(
+            tokenizer,
+            calibration_seq_length,
+            input_prep_func,
+            vocab_size,
+            pad_calibration_inputs,
+            device="cpu",
+        )
+        .record_inputs(
+            calibration_tasks,
+            calibration_limit,
+        )
+        .get_recorded_inputs()
+    )
+    print("Obtained calibration inputs, starting GPTQ quantization")
+
+    # Setup caches if model supports it
+    if hasattr(model, "setup_caches"):
+        model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length)
+
+    # Quantize with GPTQ
+    quantizer = Int4WeightOnlyGPTQQuantizer(group_size=groupsize, device=device)
+    quantizer.quantize(model, *inputs)
+    model = model.to(device)
+
+    return model
+
+
+def _apply_autoround(
+    model: torch.nn.Module,
+    tokenizer,
+    quantization: str,
+    device: str,
+    quant_lm_head: bool = False,
+) -> torch.nn.Module:
+    """Apply AutoRound quantization with calibration.
+
+    Format: autoround or autoround-<model_device>-<quant_lm_head>-<iters>-<groupsize>-<batch_size>-<seqlen>-<nsamples>-<grad_acc_steps>-<compile_optimization_process>
+
+    Args:
+        model: The model to quantize
+        tokenizer: Tokenizer for encoding calibration data
+        quantization: Quantization string with optional parameters
+        device: Device to run calibration on
+        quant_lm_head: Whether to quantize the lm_head layer
+
+    Returns:
+        The quantized model
+    """
+    from torchao.prototype.autoround.autoround_llm import quantize_model_with_autoround_
+
+    # Parse args from quantization string
+    _quant_args = quantization.split("-")
+    _default_quant_args = [False, 200, 128, 8, 2048, 128, 1, 0]
+    _model_device = _quant_args[1] if len(_quant_args) > 1 else device
+    _quant_args = _quant_args[2:]
+
+    (
+        quant_lm_head_arg,
+        iters,
+        groupsize,
+        batch_size,
+        seqlen,
+        nsamples,
+        grad_acc_steps,
+        compile_optimization_process,
+    ) = [int(x) for x in _quant_args] + _default_quant_args[len(_quant_args) :]
+
+    # Override quant_lm_head if explicitly passed
+    quant_lm_head = quant_lm_head or bool(quant_lm_head_arg)
+
+    model = model.to(_model_device)
+    print(
+        f"Quantizing model with AutoRound(iters={iters}, groupsize={groupsize}, "
+        f"quant_lm_head={quant_lm_head}, batch_size={batch_size}, seqlen={seqlen}, "
+        f"nsamples={nsamples}, gradient_accumulate_steps={grad_acc_steps}, "
+        f"compile_optimization_process={compile_optimization_process})"
+    )
+
+    # Setup caches if model supports it
+    if hasattr(model, "setup_caches"):
+        with torch.device(_model_device):
+            model.setup_caches(max_batch_size=batch_size, max_seq_length=seqlen, training=True)
+
+    # Determine target modules based on model architecture
+    # Try to find the decoder block class dynamically
+    decoder_cls = None
+    for name, module in model.named_modules():
+        cls_name = module.__class__.__name__
+        if "Block" in cls_name or "Layer" in cls_name:
+            decoder_cls = type(module)
+            break
+
+    if decoder_cls is not None:
+        if quant_lm_head:
+            is_target_module = lambda mod, fqn: isinstance(mod, decoder_cls) or "lm_head" in fqn or "output" in fqn
+        else:
+            is_target_module = lambda mod, fqn: isinstance(mod, decoder_cls)
+    else:
+        # Fallback: quantize all Linear layers except embeddings
+        if quant_lm_head:
+            is_target_module = lambda mod, fqn: isinstance(mod, torch.nn.Linear)
+        else:
+            is_target_module = lambda mod, fqn: isinstance(mod, torch.nn.Linear) and "lm_head" not in fqn
+
+    quantize_model_with_autoround_(
+        model=model,
+        tokenizer=tokenizer,
+        is_target_module=is_target_module,
+        bits=4,
+        seqlen=seqlen,
+        batch_size=batch_size,
+        iters=iters,
+        nsamples=nsamples,
+        gradient_accumulate_steps=grad_acc_steps,
+        compile_optimization_process=compile_optimization_process == 1,
+    )
+
+    model.to(device)
+    if hasattr(model, "reset_caches"):
+        model.reset_caches()
+
+    return model
+
+
+def _apply_awq(
+    model: torch.nn.Module,
+    tokenizer,
+    quantization: str,
+    eval_wrapper_cls,
+    max_seq_length: int,
+    device: str,
+    input_prep_func=None,
+) -> torch.nn.Module:
+    """Apply AWQ quantization with calibration.
+
+    Format: awq-uintx-<quant_dtype>-<group_size> or awq-uintx-<quant_dtype>-<group_size>-hqq
+
+    Args:
+        model: The model to quantize
+        tokenizer: Tokenizer for encoding calibration data
+        quantization: Quantization string (e.g., "awq-uintx-uint4-64")
+        eval_wrapper_cls: The evaluation wrapper class to use for calibration
+        max_seq_length: Maximum sequence length for calibration
+        device: Device to run calibration on
+        input_prep_func: Optional function to prepare inputs for the model
+
+    Returns:
+        The quantized model
+    """
+    from torchao.prototype.awq import (
+        AWQObservedLinear,
+        awq_uintx,
+        insert_awq_observer_,
+    )
+    from torchao.quantization import quantize_
+
+    # Parse quantization string: awq-uintx-<quant_dtype>-<group_size>[-hqq]
+    parts = quantization.split("-")
+    quant_dtype_str = parts[2] if len(parts) > 2 else "uint4"
+    group_size = int(parts[3]) if len(parts) > 3 else 64
+    use_hqq = "hqq" in quantization
+
+    quant_dtype = getattr(torch, quant_dtype_str, torch.uint4)
+
+    model = model.to(device)
+
+    # Insert AWQ observers
+    insert_awq_observer_(model, 1, 256, quant_dtype=quant_dtype, group_size=group_size)
+
+    # Run calibration
+    eval_wrapper_cls(
+        model=model,
+        tokenizer=tokenizer,
+        max_seq_length=max_seq_length,
+        input_prep_func=input_prep_func,
+        device=device,
+    ).run_eval(
+        tasks=["wikitext"],
+        limit=1,
+    )
+
+    # Convert observed model to quantized model
+    is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)
+    quantize_(
+        model,
+        awq_uintx(quant_dtype=quant_dtype, group_size=group_size, use_hqq=use_hqq),
+        is_observed_linear,
+    )
+
+    return model
+
+
+def apply_quantization(
+    model: torch.nn.Module,
+    quantization: Optional[str],
+    sparsity: Optional[str] = None,
+    tokenizer=None,
+    calibration_tasks: Optional[List[str]] = None,
+    calibration_limit: Optional[int] = None,
+    calibration_seq_length: Optional[int] = None,
+    pad_calibration_inputs: bool = False,
+    device: str = "cuda",
+    input_prep_func=None,
+    eval_wrapper_cls=None,
+    max_seq_length: int = 2048,
+    **kwargs,
+) -> torch.nn.Module:
+    """Apply quantization to a model, handling both simple configs and calibration-based methods.
+
+    This is the main entry point for applying quantization. It handles:
+    - Simple config-based quantization (int8wo, int4wo, float8, etc.)
+    - Calibration-based quantization (GPTQ, AWQ, AutoRound)
+    - Preprocessing transforms (SpinQuant)
+    - Sparsity (semi-structured, block sparse)
+
+    Args:
+        model: The model to quantize
+        quantization: Quantization method string. Supported formats:
+            - Simple: "int8wo", "int8dq", "int4wo-128", "float8wo", "float8dq-row", etc.
+ - GPTQ: "int4wo-128-gptq" (requires tokenizer and calibration params) + - AWQ: "awq-uintx-uint4-64" (requires tokenizer and eval_wrapper_cls) + - AutoRound: "autoround" or "autoround-cuda-1-200-128-8-2048-128-1-0" + - SpinQuant: "spinquant" (preprocessing, can be combined with other methods) + sparsity: Sparsity method ("semi", "2:4", "block") + tokenizer: Tokenizer for calibration-based methods + calibration_tasks: Tasks for calibration (e.g., ["wikitext"]) + calibration_limit: Number of calibration samples + calibration_seq_length: Sequence length for calibration + pad_calibration_inputs: Whether to pad short calibration sequences + device: Device to run on + input_prep_func: Function to prepare inputs for the model + eval_wrapper_cls: Evaluation wrapper class for AWQ calibration + max_seq_length: Maximum sequence length for AWQ calibration + **kwargs: Additional arguments passed to string_to_config + + Returns: + The quantized model + + Example: + >>> # Simple quantization + >>> model = apply_quantization(model, "int4wo-128") + + >>> # GPTQ quantization + >>> model = apply_quantization( + ... model, "int4wo-128-gptq", + ... tokenizer=tokenizer, + ... calibration_tasks=["wikitext"], + ... calibration_limit=128, + ... calibration_seq_length=2048, + ... ) + + >>> # SpinQuant + int4 + >>> model = apply_quantization(model, "spinquant-int4wo-128", tokenizer=tokenizer) + """ + from torchao.quantization import quantize_ + from torchao.sparsity.sparse_api import sparsify_ + + if quantization is None and sparsity is None: + return model + + # Handle SpinQuant preprocessing (can be combined with other quantization) + if quantization and "spinquant" in quantization: + model = _apply_spinquant(model) + # Remove spinquant from quantization string for further processing + quantization = quantization.replace("spinquant-", "").replace("spinquant", "") + if not quantization: + quantization = None + + # Handle calibration-based methods + if quantization and "gptq" in quantization: + if tokenizer is None: + raise ValueError("GPTQ quantization requires a tokenizer") + if calibration_tasks is None: + calibration_tasks = ["wikitext"] + if calibration_limit is None: + calibration_limit = 128 + if calibration_seq_length is None: + calibration_seq_length = 2048 + + return _apply_gptq( + model=model, + tokenizer=tokenizer, + quantization=quantization, + calibration_tasks=calibration_tasks, + calibration_limit=calibration_limit, + calibration_seq_length=calibration_seq_length, + pad_calibration_inputs=pad_calibration_inputs, + device=device, + input_prep_func=input_prep_func, + ) + + if quantization and "autoround" in quantization: + if tokenizer is None: + raise ValueError("AutoRound quantization requires a tokenizer") + + return _apply_autoround( + model=model, + tokenizer=tokenizer, + quantization=quantization, + device=device, + ) + + if quantization and quantization.startswith("awq-uintx"): + if tokenizer is None: + raise ValueError("AWQ quantization requires a tokenizer") + if eval_wrapper_cls is None: + from torchao._models._eval import TransformerEvalWrapper + eval_wrapper_cls = TransformerEvalWrapper + + return _apply_awq( + model=model, + tokenizer=tokenizer, + quantization=quantization, + eval_wrapper_cls=eval_wrapper_cls, + max_seq_length=max_seq_length, + device=device, + input_prep_func=input_prep_func, + ) + + # Handle simple config-based quantization and sparsity + config = string_to_config(quantization, sparsity, **kwargs) + + if config is not None: + # Check if it's a sparsity-only config + if 
quantization is None and sparsity is not None: + sparsify_(model, config) + else: + model = model.to(device) + quantize_(model, config) + + return model + + @torch.no_grad() def model_inference_time_in_ms(model, input_data): """Benchmark model inference time without compile overhead. diff --git a/docs/source/contributor_guide.rst b/docs/source/contributor_guide.rst index 353ba754ca..1d40ecac5f 100644 --- a/docs/source/contributor_guide.rst +++ b/docs/source/contributor_guide.rst @@ -153,7 +153,7 @@ Note: llama model (llama2/llama3) is our representative model for memory bound m * `llama `__ * `benchmark `__ - * `eval `__ + * `eval `__ * `sam `__ diff --git a/torchao/_models/README.md b/torchao/_models/README.md index 300f1ed7d3..a630e24073 100644 --- a/torchao/_models/README.md +++ b/torchao/_models/README.md @@ -33,9 +33,11 @@ sh benchmarks/_models/eval_hf_models.sh To run lm-eval for a different hf-model with AO quantization technique, run: ``` -python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization float8dq-row --tasks wikitext hellaswag +python -m benchmarks._models.llm_eval --model_id meta-llama/Llama-3.1-8B --quantization float8dq-row --tasks wikitext hellaswag ``` -Replace model id, quantization and tasks with your desired values Please refer to ([HuggingFace <-> TorchAO](https://huggingface.co/docs/transformers/main/en//quantization/torchao)) integration docs for more details about the supported quantization techniques. +Replace model id, quantization and tasks with your desired values. Please refer to ([HuggingFace <-> TorchAO](https://huggingface.co/docs/transformers/main/en//quantization/torchao)) integration docs for more details about the supported quantization techniques. + +**Note:** The unified `llm_eval.py` script supports both HuggingFace models (`--model_id`) and gpt-fast checkpoints (`--checkpoint_path`). Run `python -m benchmarks._models.llm_eval --help` to see all available options. # SAM2 sam2 is a fork of https://github.com/facebookresearch/sam2 at commit c2ec8e14a185632b0a5d8b161928ceb50197eddc diff --git a/torchao/dtypes/floatx/README.md b/torchao/dtypes/floatx/README.md index 092ef01233..0a08157f9e 100644 --- a/torchao/dtypes/floatx/README.md +++ b/torchao/dtypes/floatx/README.md @@ -49,7 +49,7 @@ outputs = quant_llm_linear(ebits, mbits, fp16_act, fp6_weight, scales) # shape ## End-to-End benchmarks -Benchmarks are run on a machine with a single 4070Ti SUPER GPU using the scripts in [_models/llama](../../_models/llama). tokens/s is measured using [generate.py](../../_models/llama/generate.py) which generates text in a latency optimized way (batchsize=1). wikitext perplexity is measured using [eval.py](../../_models/llama/eval.py) which uses [lm_eval](https://github.com/EleutherAI/lm-evaluation-harness). The model used is [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf). +Benchmarks are run on a machine with a single 4070Ti SUPER GPU using the scripts in [_models/llama](../../_models/llama). tokens/s is measured using [generate.py](../../_models/llama/generate.py) which generates text in a latency optimized way (batchsize=1). wikitext perplexity is measured using [llm_eval.py](../../../benchmarks/_models/llm_eval.py) which uses [lm_eval](https://github.com/EleutherAI/lm-evaluation-harness). The model used is [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf). Floatx quantization is run with `--precision float16`. 
The rest uses the default precision of `bfloat16`. diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md index a4b5d2801e..81d08c5c10 100644 --- a/torchao/quantization/README.md +++ b/torchao/quantization/README.md @@ -2,7 +2,7 @@ Typically quantization algorithms will have different schemes for how the activation and weights are quantized so A16W8 for instance means the activations are quantized to 16 bits wheras the weights are quantized to 8 bits. Trying out different quantization schemes in `torchao` is generally a 1 line change. Note: exact APIs are not stable, we may change them in the future. ## Benchmarks -Benchmarks and evaluation are gathered using the scripts for [generation](../_models/llama/generate.py) and [eval](../_models/llama/eval.py). Evaluation was done using the lm_eval library for tasks/data on the meta-llama/Meta-Llama-3-8B model. +Benchmarks and evaluation are gathered using the scripts for [generation](../_models/llama/generate.py) and [eval](../../benchmarks/_models/llm_eval.py). Evaluation was done using the lm_eval library for tasks/data on the meta-llama/Meta-Llama-3-8B model. ### CUDA backend | NVIDIA-A100-80GB GPU | Model | Technique | wikitext-perplexity | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) | Model Size (GB) | @@ -33,7 +33,7 @@ Benchmarks and evaluation are gathered using the scripts for [generation](../_mo | | int8wo | 7.447 | 59.49 | 447.27 | 18.60 | 7.52 -Benchmarks and evaluation for model meta-llama/Meta-Llama-3.1-8B are gathered using [generation](../_models/llama/generate.py) and [eval](../_models/llama/eval.py). Evaluation was done using the lm_eval library for tasks/data. +Benchmarks and evaluation for model meta-llama/Meta-Llama-3.1-8B are gathered using [generation](../_models/llama/generate.py) and [eval](../../benchmarks/_models/llm_eval.py). Evaluation was done using the lm_eval library for tasks/data. note: Int8 dynamic quantization works best on compute bound models like [SAM](https://github.com/pytorch-labs/segment-anything-fast) whereas Llama with batchsize=1 tends to be memory bound, thus the rather low performance. @@ -248,7 +248,7 @@ You try can out these apis with the `quantize_` api as above alongside the confi ### GPTQ Quantization We have a GPTQ quantization workflow that can be used to quantize a model to int4. More details can be found in [GPTQ](./GPTQ/README.md), -an example can be found in `torchao/_models/llama/eval.py`. +an example can be found in `benchmarks/_models/llm_eval.py` (use `--quantization int4wo-128-gptq`). ### Automatic Inductor Configuration
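For reviewers who want to smoke-test the new entry points from Python rather than through the shell script, here is a minimal sketch. It assumes this patch is applied, the repo root is on `PYTHONPATH`, and `transformers`/`lm-eval` are installed; the model id, the `int8wo` choice, and the sample limit are illustrative only, not part of the patch.

```python
# Sketch: exercise the unified helpers introduced above.
# run_evaluation() lives in benchmarks/_models/llm_eval.py and
# apply_quantization() in benchmarks/microbenchmarks/utils.py.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from benchmarks._models.llm_eval import run_evaluation
from benchmarks.microbenchmarks.utils import apply_quantization

MODEL_ID = "meta-llama/Llama-3.2-3B"  # illustrative HF model id

# One-call evaluation, roughly equivalent to:
#   python -m benchmarks._models.llm_eval --model_id meta-llama/Llama-3.2-3B \
#       --quantization int8wo --tasks wikitext --limit 8
results = run_evaluation(
    model_id=MODEL_ID,
    quantization="int8wo",
    tasks=["wikitext"],
    limit=8,  # small limit for a quick smoke test
    device="cuda",
    precision=torch.bfloat16,
)

# Or quantize an already-loaded model with the shared helper and keep going.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, torch_dtype=torch.bfloat16, device_map="cuda"
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = apply_quantization(model, "int8wo", tokenizer=tokenizer, device="cuda")
```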