diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py
index 4bc3236759..da73e7d546 100644
--- a/test/quantization/test_quant_primitives.py
+++ b/test/quantization/test_quant_primitives.py
@@ -302,11 +302,15 @@ def test_choose_qparams_token_asym(self):
                 input, dtype
             )
         )
+        # With keepdim=True, scale and zero_point now keep dimensions
+        # Match reference shapes for comparison
         scale_ref = scale_ref.squeeze()
         zp_ref = zp_ref.squeeze()
+        scale_squeezed = scale.squeeze()
+        zp_squeezed = zero_point.squeeze()

-        torch.testing.assert_close(scale, scale_ref, atol=10e-3, rtol=10e-3)
-        self.assertTrue(torch.equal(zero_point, zp_ref))
+        torch.testing.assert_close(scale_squeezed, scale_ref, atol=10e-3, rtol=10e-3)
+        self.assertTrue(torch.equal(zp_squeezed, zp_ref))

     @unittest.skipIf(is_fbcode(), "broken in fbcode")
     def test_choose_qparams_tensor_asym(self):
@@ -324,11 +328,14 @@ def test_choose_qparams_tensor_asym(self):
         scale_ref, zp_ref = torch.ops.quantized_decomposed.choose_qparams(
             input, quant_min, quant_max, eps, dtype
         )
+        # With keepdim=True, scale and zero_point now keep dimensions
         scale_ref = scale_ref.squeeze()
         zp_ref = zp_ref.squeeze()
+        scale_squeezed = scale.squeeze()
+        zp_squeezed = zero_point.squeeze()

-        self.assertTrue(torch.equal(scale, scale_ref))
-        self.assertTrue(torch.equal(zero_point, zp_ref))
+        self.assertTrue(torch.equal(scale_squeezed, scale_ref))
+        self.assertTrue(torch.equal(zp_squeezed, zp_ref))

     @unittest.skipIf(is_fbcode(), "broken in fbcode")
     def test_choose_qparams_tensor_sym(self):
@@ -346,11 +353,14 @@ def test_choose_qparams_tensor_sym(self):
         scale_ref, zp_ref = torch.ops.quantized_decomposed.choose_qparams_symmetric(
             input, quant_min, quant_max, eps, dtype
         )
+        # With keepdim=True, scale and zero_point now keep dimensions
         scale_ref = scale_ref.squeeze()
         zp_ref = zp_ref.squeeze()
+        scale_squeezed = scale.squeeze()
+        zp_squeezed = zero_point.squeeze()

-        self.assertTrue(torch.equal(scale, scale_ref))
-        self.assertTrue(torch.equal(zero_point, zp_ref))
+        self.assertTrue(torch.equal(scale_squeezed, scale_ref))
+        self.assertTrue(torch.equal(zp_squeezed, zp_ref))

     def test_quantize_activation_per_token_abs_max(self):
         input = torch.randn(10, 10)
diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py
index 9bdb3871a2..f308e05613 100644
--- a/torchao/quantization/quant_primitives.py
+++ b/torchao/quantization/quant_primitives.py
@@ -1217,7 +1217,7 @@ def choose_qparams_affine(
     eps: Optional[float] = None,
     scale_dtype: Optional[torch.dtype] = None,
     zero_point_dtype: Optional[torch.dtype] = torch.int32,
-    keepdim: bool = False,
+    keepdim: bool = True,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Args:
@@ -1231,6 +1231,7 @@ def choose_qparams_affine(
         eps (Optional[float]): minimum scale, if not provided, default to eps of input.dtype
         scale_dtype (torch.dtype): dtype for scale Tensor
         zero_point_dtype (torch.dtype): dtype for zero_point Tensor, defaults to torch.int32
+        keepdim (bool): whether to keep dimensions with size 1 in output (aligned with _choose_scale_float8)
     Now removed params:
       zero_point_domain (ZeroPointDomain): the domain that zero_point is in, defaults to Integer or None
       preserve_zero (bool): whether to preserve zero in the quantized Tensor, defaults to True
@@ -1523,7 +1524,7 @@ def _choose_qparams_affine(
     eps: Optional[float] = None,
     scale_dtype: Optional[torch.dtype] = None,
     zero_point_dtype: Optional[torch.dtype] = None,
-    keepdim: bool = False,
+    keepdim: bool = True,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """op definition that has compatible signatures with custom op library

@@ -1532,6 +1533,10 @@ def _choose_qparams_affine(
     2. find min_val/max_val based on the dimension for reduction
     3. calculate quantization parameters based on min_val/max_val based on args like `preserve_zero`
        and `zero_point_domain`
+
+    Note:
+        keepdim defaults to True to align with _choose_scale_float8 behavior. This ensures
+        scale/zero_point maintain the same rank as input, making it easier to handle downstream.
     """
     quant_min, quant_max = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max)
     assert mapping_type in [
@@ -1548,6 +1553,8 @@ def _choose_qparams_affine(
     assert len(block_size) == input.dim(), (
         f"Got input dim:{input.dim()}, block_size: {block_size}"
     )
+    # Save original input size before reshaping for later use
+    original_input_size = input.size()
     shape_for_reduction, reduction_dims = _get_reduction_params(
         block_size, input.size()
     )
@@ -1591,6 +1598,15 @@ def _choose_qparams_affine(
     if zero_point_dtype is None:
         zero_point_dtype = torch.int32

+    # Reshape scale and zero_point to match expected output shape
+    # This aligns with _choose_scale_float8 behavior
+    if keepdim:
+        output_shape = [
+            original_input_size[i] // block_size[i] for i in range(len(block_size))
+        ]
+        scale = scale.reshape(output_shape)
+        zero_point = zero_point.reshape(output_shape)
+
     return scale.to(dtype=scale_dtype, device=input.device), zero_point.to(
         dtype=zero_point_dtype
     )
diff --git a/torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py b/torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py
index 88ad165ecf..1665e2ee24 100644
--- a/torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py
+++ b/torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py
@@ -244,14 +244,8 @@ def from_hp(
             f"Unsupported IntxChooseQParamsAlgorithm: {intx_choose_qparams_algorithm}"
         )

-        # Reshape scale and zero_point to be compatible with block_size
-        # This is asserted in IntxUnpackedToInt8Tensor's __init__
-        n_blocks = []
-        for i in range(len(block_size)):
-            assert qdata.shape[i] % block_size[i] == 0
-            n_blocks.append(qdata.shape[i] // block_size[i])
-        scale = scale.reshape(*n_blocks)
-        zero_point = zero_point.reshape(*n_blocks)
+        # Note: scale and zero_point already have the correct shape from choose_qparams_affine
+        # which now uses keepdim=True and reshapes to match block_size expectations

         return IntxUnpackedToInt8Tensor(
             qdata=qdata,
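
To make the behavior change concrete, here is a minimal sketch (not part of the patch) of what the new `keepdim=True` default means for callers. It assumes the `choose_qparams_affine` signature shown in the diff and uses a per-token setup mirroring the updated tests:

```python
# Minimal sketch; assumes the choose_qparams_affine signature from the diff above.
import torch

from torchao.quantization.quant_primitives import MappingType, choose_qparams_affine

input = torch.randn(10, 10)
block_size = (1, 10)  # per-token: one quantization block per row

scale, zero_point = choose_qparams_affine(
    input,
    MappingType.ASYMMETRIC,
    block_size,
    target_dtype=torch.int8,
)

# With keepdim=True (the new default), scale/zero_point keep the input's rank,
# with one element per block: (10, 10) with blocks (1, 10) -> torch.Size([10, 1]).
# Previously (keepdim=False) this would have been the squeezed torch.Size([10]).
print(scale.shape, zero_point.shape)

# Callers that relied on the old squeezed shape can squeeze, as the updated
# tests do, or pass keepdim=False explicitly.
scale_1d = scale.squeeze()  # torch.Size([10])
```

This is also why the `from_hp` reshape in `intx_unpacked_to_int8_tensor.py` becomes redundant: the `(n_blocks_0, n_blocks_1, ...)` shape it used to construct is now produced inside `_choose_qparams_affine` itself.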