Align _choose_qparams_affine with _choose_scale_float8 behavior #3447
base: main
Conversation
Changes the keepdim default from False to True in _choose_qparams_affine to match _choose_scale_float8 behavior. This ensures scale/zero_point maintain the same rank as the input tensor, making downstream handling more consistent. Fixes pytorch#3324
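To illustrate the rank difference the description refers to, here is a minimal standalone sketch in plain PyTorch (not torchao's actual implementation; the per-row reduction is just an assumed example):

```python
import torch

# Per-row quantization of a (4, 8) tensor, i.e. one scale per row.
x = torch.randn(4, 8)

scale_old = x.abs().amax(dim=-1, keepdim=False)  # shape (4,)  - rank dropped
scale_new = x.abs().amax(dim=-1, keepdim=True)   # shape (4, 1) - same rank as x

# With keepdim=True the scale broadcasts against x without extra reshaping,
# which is the "more consistent downstream handling" the PR description means.
q = torch.clamp(torch.round(x / scale_new * 127), -128, 127).to(torch.int8)
print(scale_old.shape, scale_new.shape, q.shape)
```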
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3447
Note: Links to docs will display an error until the docs builds have been completed.
❗ 1 Active SEV: there is 1 currently active SEV. If your PR is affected, please view it below.
❌ 13 New Failures: as of commit bdf1210 with merge base aa21b80, the following jobs have failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
# Reshape scale and zero_point to match expected output shape
# This aligns with _choose_scale_float8 behavior
if keepdim:
    output_shape = [
        original_input_size[i] // block_size[i] for i in range(len(block_size))
    ]
    scale = scale.reshape(output_shape)
    zero_point = zero_point.reshape(output_shape)
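For concreteness, a shape-only sketch of what this hunk computes, using the input (10, 20) / block_size (10, 4) example from the review thread below:

```python
# Shape arithmetic only; values taken from the example discussed below.
original_input_size = (10, 20)
block_size = (10, 4)

output_shape = [
    original_input_size[i] // block_size[i] for i in range(len(block_size))
]
print(output_shape)  # [1, 5] -> one scale / zero_point per (10, 4) block
```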
is this needed? keepdim=True is already used in int8_tensor.py
Edit: oh OK, I think it is needed because of the reshapes we are doing before:
ao/torchao/quantization/quant_primitives.py
Line 1554 in aa21b80
input = input.view(shape_for_reduction)
Yes, this reshape is needed! int8_tensor.py passes scale/zero_point directly to quantize_affine, which internally reshapes the scale at line 461, so it doesn't need the output to be pre-reshaped. But the thing is, IntxUnpackedToInt8Tensor.__init__ (lines 131-136) asserts that scale.shape must exactly match tuple(n_blocks) before passing to quantize_affine:
assert scale.shape == tuple(n_blocks), (
    f"Expected scale to have shape {n_blocks} (inferred from block_size={block_size}), but got {scale.shape}"
)
basically without this reshape:
- keepdim=True gives scale shape like (1, 5, 1) for block_size (10, 4) on input (10, 20)
- But IntxUnpackedToInt8Tensor expects (1, 5)
- The assertion would fail!
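A standalone repro of those shapes in plain PyTorch (the internal view torchao builds may differ; this only reproduces the shapes quoted above):

```python
import torch

x = torch.randn(10, 20)
block_size = (10, 4)

# Group into (10, 4) blocks and reduce over the block dims with keepdim=True.
blocks = x.view(10, 5, 4)
scale_keepdim = blocks.abs().amax(dim=(0, 2), keepdim=True)
print(scale_keepdim.shape)  # torch.Size([1, 5, 1])

# IntxUnpackedToInt8Tensor expects scale.shape == n_blocks:
n_blocks = tuple(x.shape[i] // block_size[i] for i in range(len(block_size)))
print(n_blocks)  # (1, 5)

# The reshape added in this PR bridges the two shapes:
scale = scale_keepdim.reshape(n_blocks)
print(scale.shape)  # torch.Size([1, 5])
```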
yeah that's correct, I'll approve the CI to run all the tests to see, especially this one:
ao/test/quantization/test_quant_primitives.py
Line 569 in aa21b80
block_size = (3, 3, 2, 2)
also in a future PR we could remove some of the reshaping logic in quantize_affine/dequantize_affine as well:
ao/torchao/quantization/quant_primitives.py
Lines 453 to 461 in aa21b80
shape_for_reduction, reduction_dims = _get_reduction_params(
    block_size, input.size()
)
original_shape = input.shape
input = input.view(shape_for_reduction)
shape_after_reduction = shape_for_reduction
for i in reduction_dims:
    shape_after_reduction[i] = 1
scale = scale.view(shape_after_reduction)
also eventually remove the block_size arg from these ops (bc-breaking)
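Rough illustration of why that reshaping could become unnecessary (plain PyTorch, not the actual torchao code): a scale that already carries the reduced dims broadcasts against the block-shaped view of the input directly.

```python
import torch

x = torch.randn(10, 20)
blocks = x.view(10, 5, 4)  # block-shaped view for block_size (10, 4)

# A keepdim-shaped scale of shape (1, 5, 1) broadcasts against blocks with no
# extra scale.view(shape_after_reduction) step before quantizing.
scale = blocks.abs().amax(dim=(0, 2), keepdim=True) / 127.0
q = torch.clamp(torch.round(blocks / scale), -128, 127).to(torch.int8).view(10, 20)
print(q.shape, scale.shape)
```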
sure, sounds like a plan, happy to contribute
Thanks, I think it's a good start, we can remove
I see 25 integration tests failed due to backward compatibility issues with the
it's expected, I think maybe just don't change the default for now, but turn keepdim to True in these tests one by one to make sure these tests are fixed; fixing all the callsites before making the switch would be better
Changes the keepdim default from False to True in _choose_qparams_affine to match _choose_scale_float8 behavior. This ensures scale/zero_point maintain the same rank as the input tensor, making downstream handling more consistent.
Part 1 of fixing #3324
Changes

Core Changes (torchao/quantization/quant_primitives.py)
- keepdim: bool = False → keepdim: bool = True in both choose_qparams_affine (line 1220) and _choose_qparams_affine (line 1526), matching _choose_scale_float8 behavior
- Save original_input_size before reshaping to compute the correct output shape, aligning with _choose_scale_float8

Workflow Simplification (torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py)

Test Updates (test/quantization/test_quant_primitives.py)
- test_choose_qparams tests now pass