Merged
edbfae5
initial
MekkCyber Nov 27, 2025
53a2f59
quantization fixed
MekkCyber Nov 27, 2025
29febd4
up
MekkCyber Nov 28, 2025
5824ecb
working
MekkCyber Dec 4, 2025
f86696f
fix
MekkCyber Dec 4, 2025
58b8c99
style
MekkCyber Dec 4, 2025
2e02c2d
clean
MekkCyber Dec 4, 2025
9072c05
reset
MekkCyber Dec 4, 2025
97a8293
style
MekkCyber Dec 4, 2025
659398d
Merge branch 'main' into fix-fp-quant
MekkCyber Dec 4, 2025
c162fdc
Merge branch 'main' into fix-fp-quant
MekkCyber Dec 4, 2025
49ad55a
Merge branch 'main' into fix-fp-quant
MekkCyber Dec 5, 2025
f6c8ab3
Merge branch 'main' into fix-fp-quant
MekkCyber Dec 5, 2025
ebc5d13
Merge branch 'main' into fix-fp-quant
MekkCyber Dec 5, 2025
2be357d
Merge branch 'main' into fix-fp-quant
SunMarc Dec 5, 2025
cc74946
Merge branch 'main' into fix-fp-quant
MekkCyber Dec 5, 2025
4e041c0
Merge branch 'main' into fix-fp-quant
MekkCyber Dec 5, 2025
1fb8bcf
Merge branch 'main' into fix-fp-quant
MekkCyber Dec 8, 2025
3b5077f
rm duplicate
MekkCyber Dec 8, 2025
5276aee
Merge branch 'main' into fix-fp-quant
MekkCyber Dec 8, 2025
b0ec2fe
Merge branch 'main' into fix-fp-quant
SunMarc Dec 8, 2025
a01527e
Merge branch 'main' into fix-fp-quant
SunMarc Dec 8, 2025
bac4828
Merge branch 'main' into fix-fp-quant
SunMarc Dec 8, 2025
2137ce9
Merge branch 'main' into fix-fp-quant
SunMarc Dec 8, 2025
24e1247
Merge branch 'main' into fix-fp-quant
MekkCyber Dec 8, 2025
ebd8ad9
ci: empty commit
MekkCyber Dec 8, 2025
77c4c0d
Merge branch 'main' into fix-fp-quant
MekkCyber Dec 9, 2025
e39639b
Merge branch 'main' into fix-fp-quant
SunMarc Dec 9, 2025
e415303
Merge branch 'main' into fix-fp-quant
MekkCyber Dec 9, 2025
f8ed9d6
Merge branch 'main' into fix-fp-quant
SunMarc Dec 9, 2025
c5754d2
Merge branch 'main' into fix-fp-quant
SunMarc Dec 9, 2025
47d3b93
Merge branch 'main' into fix-fp-quant
MekkCyber Dec 10, 2025
946c2dc
Merge branch 'main' into fix-fp-quant
SunMarc Dec 10, 2025
4 changes: 2 additions & 2 deletions src/transformers/conversion_mapping.py
@@ -142,11 +142,11 @@ def _build_checkpoint_conversion_mapping():
     if hasattr(torch.nn.utils.parametrizations, "weight_norm"):
         mapping["legacy"] += [
             WeightRenaming(
-                source_patterns="weight_g",
+                source_patterns=r"weight_g$",
                 target_patterns="parametrizations.weight.original0",
             ),
             WeightRenaming(
-                source_patterns="weight_v",
+                source_patterns=r"weight_v$",
                 target_patterns="parametrizations.weight.original1",
             ),
         ]
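A side note on the anchoring change above: an unanchored `weight_g` pattern also matches substrings of longer parameter names, including the `weight_global_scale` keys this PR introduces for FP-Quant layers, whereas `r"weight_g$"` only matches keys that actually end in `weight_g`. A minimal sketch, assuming the `source_patterns` are applied as regular-expression searches:

```python
import re

# Minimal sketch (our own, not part of the diff): compare the unanchored and anchored
# patterns against a legacy weight-norm key and an FP-Quant key.
for key in ["encoder.conv.weight_g", "model.layers.0.mlp.up_proj.weight_global_scale"]:
    print(key, bool(re.search("weight_g", key)), bool(re.search(r"weight_g$", key)))
# encoder.conv.weight_g True True
# model.layers.0.mlp.up_proj.weight_global_scale True False
```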
92 changes: 92 additions & 0 deletions src/transformers/integrations/fp_quant.py
@@ -13,6 +13,10 @@
 # limitations under the License.
 "FP-Quant integration file"
 
+from typing import Optional
+
+import torch
+
 from ..utils import (
     is_fp_quant_available,
 )
@@ -24,6 +28,94 @@
 
 from transformers.utils.quantization_config import FPQuantConfig
 
+from ..core_model_loading import ConversionOps
+from ..quantizers.quantizers_utils import get_module_from_name
+
+
+class FpQuantQuantize(ConversionOps):
+    def __init__(self, hf_quantizer):
+        self.hf_quantizer = hf_quantizer
+
+    def convert(
+        self,
+        input_dict: torch.Tensor,
+        model: Optional[torch.nn.Module] = None,
+        missing_keys: Optional[list[str]] = None,
+        **kwargs,
+    ) -> dict[str, torch.Tensor]:
+        target_key, value = tuple(input_dict.items())[0]
+        value = value[0]
+        # Loading master weights or an unquantized checkpoint
+        weight = torch.nn.Parameter(value)
+        module, _ = get_module_from_name(model, target_key)
+        module.weight = weight
+
+        # Let pre-forward handle the quantization and set None where necessary
+        # This operation will quantize the weights internally
+        with torch.cuda.device(value.device):
+            module.pre_forward()
+
+        prefix_target_key = target_key.rsplit(".", 1)[0]
+
+        # these keys are set inside the module.pre_forward() method, so we remove them from the missing keys list
+        missing_keys.discard(target_key)
+        missing_keys.discard(f"{prefix_target_key}.backward_hadamard_matrix")
+        missing_keys.discard(f"{prefix_target_key}.forward_hadamard_matrix")
+        missing_keys.discard(f"{prefix_target_key}.act_global_scale")
+        missing_keys.discard(f"{prefix_target_key}.weight_global_scale")
+        missing_keys.discard(f"{prefix_target_key}.qweight")
+        missing_keys.discard(f"{prefix_target_key}.scales")
+        missing_keys.discard(f"{prefix_target_key}.dqweight")
+        return {}
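For context, a minimal usage sketch of the path `FpQuantQuantize` serves: loading an unquantized checkpoint with an `FPQuantConfig` so that each `FPQuantLinear` weight is quantized on the fly via `pre_forward()`. The model id is a placeholder, the default `FPQuantConfig()` settings are assumed to be sufficient, and the `fp_quant` package (with its kernels) is assumed to be installed:

```python
from transformers import AutoModelForCausalLM
from transformers.utils.quantization_config import FPQuantConfig

# Quantize-on-load sketch: since the checkpoint carries no pre-quantized tensors, every
# FPQuantLinear weight is routed through FpQuantQuantize.convert and quantized on a CUDA device.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",  # placeholder model id
    quantization_config=FPQuantConfig(),  # default settings assumed
    device_map="cuda",
)
```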


+class FpQuantDeserialize(ConversionOps):
+    def __init__(self, hf_quantizer):
+        self.hf_quantizer = hf_quantizer
+
+    def convert(
+        self,
+        input_dict: torch.Tensor,
+        model: Optional[torch.nn.Module] = None,
+        full_layer_name: str | None = None,
+        missing_keys: Optional[list[str]] = None,
+        **kwargs,
+    ) -> dict[str, torch.Tensor]:
+        target_key, value = tuple(input_dict.items())[0]
+        value = value[0] if isinstance(value, list) else value
+        module, _ = get_module_from_name(model, target_key)
+        # The module holds either:
+        # * `weight` when `store_master_weights=True`
+        # * `qweight` and `scales` when `store_master_weights=False` and `pseudoquantization=False`
+        # * `dqweight` when `store_master_weights=False` and `pseudoquantization=True`
+        if target_key == ".qweight":
+            # Loading a real quantized checkpoint without master weights
+            qweight = torch.nn.Parameter(
+                value,
+                requires_grad=False,
+            )
+
+            return {
+                ".qweight": qweight,
+                # the way the FPQuantLinear module is designed, these parameters are expected in the model
+                # even though they are not used, so we need to set them to zeros
+                ".weight": torch.nn.Parameter(torch.zeros(0)),
+                ".dqweight": torch.nn.Parameter(torch.zeros(0)),
+            }
+
+        if target_key == ".dqweight":
+            # Loading a pseudo-quantized checkpoint without master weights
+            dqweight = torch.nn.Parameter(value)
+
+            return {
+                ".dqweight": dqweight,
+                # the way the FPQuantLinear module is designed, these parameters are expected in the model
+                # even though they are not used, so we need to set them to zeros
+                ".weight": torch.nn.Parameter(torch.zeros(0)),
+                ".qweight": torch.nn.Parameter(torch.zeros(0)),
+                ".scales": torch.nn.Parameter(torch.zeros(0)),
+            }
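To make the three branches concrete, here is a purely illustrative sketch of the per-layer tensor names each storage mode produces; the layer prefix is made up and only the suffixes come from the code above:

```python
# Illustrative only: per-layer tensor names for the three FP-Quant storage modes.
layer = "model.layers.0.mlp.up_proj"  # hypothetical layer prefix

master_weights = {f"{layer}.weight": "full-precision master weight"}  # store_master_weights=True
real_quantized = {  # store_master_weights=False, pseudoquantization=False
    f"{layer}.qweight": "packed quantized weight",
    f"{layer}.scales": "quantization scales",
}
pseudo_quantized = {f"{layer}.dqweight": "dequantized weight"}  # store_master_weights=False, pseudoquantization=True
```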


 def adapt_fp_quant_config(config: FPQuantConfig):
     if config.forward_dtype == "mxfp4":
33 changes: 29 additions & 4 deletions src/transformers/quantizers/quantizer_fp_quant.py
@@ -138,10 +138,7 @@ def _process_model_before_weight_loading(
 
         from ..integrations.fp_quant import adapt_fp_quant_config
 
-        replace_with_fp_quant_linear(
-            model,
-            fp_quant_linear_config=adapt_fp_quant_config(self.quantization_config),
-        )
+        replace_with_fp_quant_linear(model, fp_quant_linear_config=adapt_fp_quant_config(self.quantization_config))
         model.config.quantization_config = self.quantization_config
 
     def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> list[str]:
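`replace_with_fp_quant_linear` swaps the eligible `nn.Linear` modules for `FPQuantLinear` before any weights are loaded. A small self-contained helper (the function name is ours, not part of the diff) for checking which submodules were actually replaced:

```python
import torch

def list_fp_quant_modules(model: torch.nn.Module) -> list[str]:
    # Match on the class name so this sketch does not need to import fp_quant itself.
    return [name for name, module in model.named_modules() if type(module).__name__ == "FPQuantLinear"]
```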
@@ -178,3 +175,31 @@ def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **
             return True
         else:
             return False
+
+    def get_quantize_ops(self):
+        from ..integrations.fp_quant import FpQuantQuantize
+
+        return FpQuantQuantize(self)
+
+    def get_weight_conversions(self):
+        from ..core_model_loading import WeightConverter
+        from ..integrations.fp_quant import FpQuantDeserialize
+
+        if self.pre_quantized:
+            if self.quantization_config.pseudoquantization:
+                return [
+                    WeightConverter(
+                        source_patterns=[".dqweight"],
+                        target_patterns=".dqweight",
+                        operations=[FpQuantDeserialize(self)],
+                    ),
+                ]
+            else:
+                return [
+                    WeightConverter(
+                        source_patterns=[".qweight"],
+                        target_patterns=".qweight",
+                        operations=[FpQuantDeserialize(self)],
+                    ),
+                ]
+        return []
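And the counterpart of the earlier quantize-on-load sketch: loading an already quantized FP-Quant checkpoint, which is when `pre_quantized` is true and the `WeightConverter`s above route `.qweight` or `.dqweight` tensors through `FpQuantDeserialize`. The repository id below is a placeholder:

```python
from transformers import AutoModelForCausalLM

# Deserialization sketch: the checkpoint's stored quantization_config marks it as pre-quantized,
# so its qweight/scales (or dqweight) tensors are loaded through FpQuantDeserialize.convert.
model = AutoModelForCausalLM.from_pretrained(
    "your-org/some-model-fp-quant",  # placeholder id of a pre-quantized FP-Quant checkpoint
    device_map="cuda",
)
```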