22 changes: 16 additions & 6 deletions test/quantization/test_quant_primitives.py
@@ -302,11 +302,15 @@ def test_choose_qparams_token_asym(self):
input, dtype
)
)
# With keepdim=True, scale and zero_point now keep dimensions
# Match reference shapes for comparison
scale_ref = scale_ref.squeeze()
zp_ref = zp_ref.squeeze()
scale_squeezed = scale.squeeze()
zp_squeezed = zero_point.squeeze()

torch.testing.assert_close(scale, scale_ref, atol=10e-3, rtol=10e-3)
self.assertTrue(torch.equal(zero_point, zp_ref))
torch.testing.assert_close(scale_squeezed, scale_ref, atol=10e-3, rtol=10e-3)
self.assertTrue(torch.equal(zp_squeezed, zp_ref))

@unittest.skipIf(is_fbcode(), "broken in fbcode")
def test_choose_qparams_tensor_asym(self):
@@ -324,11 +328,14 @@ def test_choose_qparams_tensor_asym(self):
scale_ref, zp_ref = torch.ops.quantized_decomposed.choose_qparams(
input, quant_min, quant_max, eps, dtype
)
# With keepdim=True, scale and zero_point now keep dimensions
scale_ref = scale_ref.squeeze()
zp_ref = zp_ref.squeeze()
scale_squeezed = scale.squeeze()
zp_squeezed = zero_point.squeeze()

self.assertTrue(torch.equal(scale, scale_ref))
self.assertTrue(torch.equal(zero_point, zp_ref))
self.assertTrue(torch.equal(scale_squeezed, scale_ref))
self.assertTrue(torch.equal(zp_squeezed, zp_ref))

@unittest.skipIf(is_fbcode(), "broken in fbcode")
def test_choose_qparams_tensor_sym(self):
@@ -346,11 +353,14 @@ def test_choose_qparams_tensor_sym(self):
scale_ref, zp_ref = torch.ops.quantized_decomposed.choose_qparams_symmetric(
input, quant_min, quant_max, eps, dtype
)
# With keepdim=True, scale and zero_point now keep dimensions
scale_ref = scale_ref.squeeze()
zp_ref = zp_ref.squeeze()
scale_squeezed = scale.squeeze()
zp_squeezed = zero_point.squeeze()

self.assertTrue(torch.equal(scale, scale_ref))
self.assertTrue(torch.equal(zero_point, zp_ref))
self.assertTrue(torch.equal(scale_squeezed, scale_ref))
self.assertTrue(torch.equal(zp_squeezed, zp_ref))

def test_quantize_activation_per_token_abs_max(self):
input = torch.randn(10, 10)
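A minimal sketch (editorial, not part of the diff) of why the tests now squeeze both sides before comparing: a keepdim reduction keeps a size-1 dim that the reference path drops, so the raw shapes no longer match.

import torch

x = torch.randn(10, 10)
min_ref = x.amin(dim=-1)                 # reference-style reduction: shape (10,)
min_keep = x.amin(dim=-1, keepdim=True)  # keepdim=True reduction: shape (10, 1)
# Squeezing both sides puts them on a common shape for torch.equal /
# torch.testing.assert_close, mirroring the updated assertions above.
assert torch.equal(min_ref.squeeze(), min_keep.squeeze())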
20 changes: 18 additions & 2 deletions torchao/quantization/quant_primitives.py
@@ -1217,7 +1217,7 @@ def choose_qparams_affine(
eps: Optional[float] = None,
scale_dtype: Optional[torch.dtype] = None,
zero_point_dtype: Optional[torch.dtype] = torch.int32,
keepdim: bool = False,
keepdim: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
@@ -1231,6 +1231,7 @@
eps (Optional[float]): minimum scale, if not provided, default to eps of input.dtype
scale_dtype (torch.dtype): dtype for scale Tensor
zero_point_dtype (torch.dtype): dtype for zero_point Tensor, defaults to torch.int32
keepdim (bool): whether reduced dimensions are kept with size 1 in the output, defaults to True (aligned with _choose_scale_float8)
Now removed params:
zero_point_domain (ZeroPointDomain): the domain that zero_point is in, defaults to Integer or None
preserve_zero (bool): whether to preserve zero in the quantized Tensor, defaults to True
@@ -1523,7 +1524,7 @@ def _choose_qparams_affine(
eps: Optional[float] = None,
scale_dtype: Optional[torch.dtype] = None,
zero_point_dtype: Optional[torch.dtype] = None,
keepdim: bool = False,
keepdim: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""op definition that has compatible signatures with custom op library

@@ -1532,6 +1533,10 @@
2. find min_val/max_val based on the dimension for reduction
3. calculate quantization parameters based on min_val/max_val based on args like `preserve_zero`
and `zero_point_domain`

Note:
keepdim defaults to True to align with _choose_scale_float8 behavior. This ensures
scale/zero_point maintain the same rank as input, making it easier to handle downstream.
"""
quant_min, quant_max = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max)
assert mapping_type in [
@@ -1548,6 +1553,8 @@
assert len(block_size) == input.dim(), (
f"Got input dim:{input.dim()}, block_size: {block_size}"
)
# Save original input size before reshaping for later use
original_input_size = input.size()
shape_for_reduction, reduction_dims = _get_reduction_params(
block_size, input.size()
)
@@ -1591,6 +1598,15 @@
if zero_point_dtype is None:
zero_point_dtype = torch.int32

# Reshape scale and zero_point to match expected output shape
# This aligns with _choose_scale_float8 behavior
if keepdim:
output_shape = [
original_input_size[i] // block_size[i] for i in range(len(block_size))
]
scale = scale.reshape(output_shape)
zero_point = zero_point.reshape(output_shape)
Comment on lines +1601 to +1608

jerryzh168 (Contributor) commented on Dec 5, 2025:

is this needed? keepdim=True is already used in int8_tensor.py

Edit: oh OK, I think it is needed because of the reshapes we are doing before:

input = input.view(shape_for_reduction)

Author replied:

Yes, this reshape is needed! int8_tensor.py passes scale/zero_point directly to quantize_affine, which internally reshapes the scale at line 461, so it doesn't need the output to be pre-reshaped. But the thing is, IntxUnpackedToInt8Tensor.__init__ (lines 131-136) asserts that scale.shape must exactly match tuple(n_blocks) before passing to quantize_affine:

  assert scale.shape == tuple(n_blocks), (
      f"Expected scale to have shape {n_blocks} (inferred from "
      f"block_size={block_size}), but got {scale.shape}"
  )

Author added:

Basically, without this reshape:

  • keepdim=True gives scale shape like (1, 5, 1) for block_size (10, 4) on input (10, 20)
  • But IntxUnpackedToInt8Tensor expects (1, 5)
  • The assertion would fail! (see the sketch below)
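
A short sketch (editorial; shape_for_reduction and reduction_dims are written out by hand here rather than computed via _get_reduction_params) that reproduces the mismatch with plain tensor ops:

import torch

x = torch.randn(10, 20)                  # input (10, 20), block_size (10, 4)
shape_for_reduction = (10, 5, 4)         # dim 0 reduced whole, dim 1 split into 5 blocks of 4
reduction_dims = (0, 2)

min_val = x.view(shape_for_reduction).amin(dim=reduction_dims, keepdim=True)
print(min_val.shape)                     # torch.Size([1, 5, 1]): the keepdim shape

n_blocks = (10 // 10, 20 // 4)           # (1, 5): what IntxUnpackedToInt8Tensor asserts
print(min_val.reshape(n_blocks).shape)   # torch.Size([1, 5]): after the new reshape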

jerryzh168 (Contributor) replied:

yeah that's correct, I'll approve the CI to run all the tests to see, especially this one:

block_size = (3, 3, 2, 2)

jerryzh168 (Contributor) added:

also in a future PR we could remove some of the reshaping logic in quantize_affine/dequantize_affine as well:

shape_for_reduction, reduction_dims = _get_reduction_params(
    block_size, input.size()
)
original_shape = input.shape
input = input.view(shape_for_reduction)
shape_after_reduction = shape_for_reduction
for i in reduction_dims:
    shape_after_reduction[i] = 1
scale = scale.view(shape_after_reduction)

also eventually remove the block_size arg from these ops (bc-breaking)
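
A sketch (editorial) of why that cleanup could work: a scale that already carries the keepdim shape broadcasts against the block-reshaped input directly, so the scale.view above becomes redundant.

import torch

x = torch.randn(10, 20).view(10, 5, 4)   # input viewed as shape_for_reduction
scale = torch.rand(1, 5, 1) + 0.1        # keepdim-shaped scale: no extra view needed
q = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
print(q.shape)                           # torch.Size([10, 5, 4]): broadcasting handles the blocks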

Author replied:

sure, sounds like a plan, happy to contribute


return scale.to(dtype=scale_dtype, device=input.device), zero_point.to(
dtype=zero_point_dtype
)
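
For reference, a usage sketch of the new default (editorial; assumes the public signature choose_qparams_affine(input, mapping_type, block_size, target_dtype, ...) exported from torchao.quantization.quant_primitives):

import torch
from torchao.quantization.quant_primitives import MappingType, choose_qparams_affine

x = torch.randn(10, 20)
scale, zero_point = choose_qparams_affine(x, MappingType.ASYMMETRIC, (10, 4), torch.int8)
# With keepdim=True plus the reshape above, both qparams come back with
# shape (10 // 10, 20 // 4) == (1, 5) rather than (1, 5, 1).
print(scale.shape, zero_point.shape)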
@@ -244,14 +244,8 @@ def from_hp(
f"Unsupported IntxChooseQParamsAlgorithm: {intx_choose_qparams_algorithm}"
)

# Reshape scale and zero_point to be compatible with block_size
# This is asserted in IntxUnpackedToInt8Tensor's __init__
n_blocks = []
for i in range(len(block_size)):
assert qdata.shape[i] % block_size[i] == 0
n_blocks.append(qdata.shape[i] // block_size[i])
scale = scale.reshape(*n_blocks)
zero_point = zero_point.reshape(*n_blocks)
# Note: scale and zero_point already have the correct shape from choose_qparams_affine
# which now uses keepdim=True and reshapes to match block_size expectations

return IntxUnpackedToInt8Tensor(
qdata=qdata,
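
A final sketch (editorial, same assumed API as above) checking the invariant that lets from_hp drop its reshape, using the 4-D block_size case flagged for CI in the thread:

import torch
from torchao.quantization.quant_primitives import MappingType, choose_qparams_affine

block_size = (3, 3, 2, 2)
x = torch.randn(3, 6, 4, 4)
scale, zero_point = choose_qparams_affine(x, MappingType.ASYMMETRIC, block_size, torch.int8)
n_blocks = tuple(x.shape[i] // block_size[i] for i in range(len(block_size)))
assert scale.shape == n_blocks           # (1, 2, 2, 2): no reshape needed downstream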