Support HQQ for XPU/NPU int4 quantization #3465
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3465
Note: Links to docs will display an error until the docs builds have been completed. This comment was automatically generated by Dr. CI and updates every 15 minutes.
@pytorchbot label "ciflow/xpu"

To add these label(s) (ciflow/xpu) to the PR, please first approve the workflows that are awaiting approval (scroll to the bottom of this page). This helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.
@namgyu-youn Per my understanding, HQQ uses float zero points, but XPU only supports int zero points for now.
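For context, here is a minimal sketch of the two zero-point conventions at play. The formulas follow torchao's ZeroPointDomain semantics as I understand them; using mid_point = 8 for 4 bits is this sketch's assumption:

```python
import torch

# Illustration only: the two zero-point domains for 4-bit weights.
q = torch.randint(0, 16, (4,), dtype=torch.uint8).float()  # quantized values
scale = torch.tensor(0.1)
zero_int = torch.tensor(7.0)  # INT-domain zero point (what XPU kernels consume)
mid_point = 8.0               # assumed mid-point for 4-bit values

# INT domain: w ≈ (q - zero_int) * scale
w_int_domain = (q - zero_int) * scale

# FLOAT domain (tinygemm-style): w ≈ (q - mid_point) * scale + zero_float
# The float zero point that reproduces the same dequantized values:
zero_float = (mid_point - zero_int) * scale
w_float_domain = (q - mid_point) * scale + zero_float

assert torch.allclose(w_int_domain, w_float_domain)
```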
Thanks, I missed it.
I also tested the following code locally for verification. Could you review this PR?

```python
import torch

from torchao.quantization.quant_primitives import (
    _choose_qparams_and_quantize_affine_hqq,
)


def test_hqq_raw_output():
    """Test raw_output=True (INT domain) vs raw_output=False (FLOAT domain)."""
    w = torch.randn(256, 128, dtype=torch.bfloat16, device="cpu")

    # Test INT domain (XPU)
    _, _, zero_int, _ = _choose_qparams_and_quantize_affine_hqq(
        w,
        nbits=4,
        group_size=64,
        axis=1,
        compute_dtype=torch.bfloat16,
        device="cpu",
        verbose=False,
        raw_output=True,
    )

    # Test FLOAT domain (NPU)
    _, _, zero_float, _ = _choose_qparams_and_quantize_affine_hqq(
        w,
        nbits=4,
        group_size=64,
        axis=1,
        compute_dtype=torch.bfloat16,
        device="cpu",
        verbose=False,
        raw_output=False,
    )

    # Validate INT domain (0-15 range, int8 compatible)
    int_valid = (zero_int.min() >= 0) and (zero_int.max() <= 15)
    int8_works = (zero_int.to(torch.int8).min() >= 0) and (
        zero_int.to(torch.int8).max() <= 15
    )

    # Validate FLOAT domain (not integers)
    float_valid = not torch.allclose(
        zero_float.cpu().float(), zero_float.cpu().float().round(), atol=1e-3
    )

    print(
        f"raw_output=True (XPU): zeros in [0,15]={int_valid}, int8 safe={int8_works}"
    )
    print(f"raw_output=False (NPU): float domain={float_valid}")
    return int_valid and int8_works and float_valid


if __name__ == "__main__":
    success = test_hqq_raw_output()
    exit(0 if success else 1)
```

And the result is:

```
> python test.py
raw_output=True (XPU): zeros in [0,15]=True, int8 safe=True
raw_output=False (NPU): float domain=True
```
I don't understand why raw_output=True is related to XPU. And why is this also the case for NPU?
Let me make it clear: the raw_output flag selects the zero-point domain returned by _choose_qparams_and_quantize_affine_hqq (raw_output=True returns integer zero points, raw_output=False returns float zero points).

That means the kernel for XPU requires an INT zero point (zp), while the kernel for NPU requires a FLOAT zp, as you mentioned before. Therefore, we can produce the dtype each kernel expects by using the raw_output flag.
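To make the mapping concrete, here is a minimal sketch of how a caller could pick the flag per target; the helper name and its argument are hypothetical, only the device-to-domain mapping comes from the discussion above:

```python
def _hqq_raw_output_for(target: str) -> bool:
    """Hypothetical helper: map a kernel target to the raw_output flag."""
    if target == "xpu":
        return True   # XPU kernels consume integer zero points
    elif target == "npu":
        return False  # NPU kernels consume float zero points
    raise ValueError(f"unsupported target: {target}")
```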
Summary:
Add HQQ support for Int4PlainInt32Tensor, which is the W4A-INT quantization API for XPU/NPU devices.

Related Issue/PR: #3013
Test plan: CI
Future Plan:
Although HQQ showed better performance than TinyGEMM in Int4TilePackedTo4dTensor, it is not yet the default algorithm for XPU/NPU because a performance comparison (TinyGEMM vs. HQQ) is unavailable locally. We can make HQQ the default after that performance test.
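For reference, a hypothetical end-to-end sketch of what enabling HQQ on the plain-int32 int4 path might look like. The config fields int4_packing_format and int4_choose_qparams_algorithm (and their string values) are assumptions of this sketch, not the confirmed API surface of this PR:

```python
import torch
import torch.nn as nn

from torchao.quantization import Int4WeightOnlyConfig, quantize_

# Hypothetical sketch: the field names/values below are assumptions.
# A real run would target an XPU/NPU device rather than CPU.
model = nn.Sequential(nn.Linear(128, 256, dtype=torch.bfloat16))
quantize_(
    model,
    Int4WeightOnlyConfig(
        group_size=64,
        int4_packing_format="plain_int32",    # assumed: selects Int4PlainInt32Tensor
        int4_choose_qparams_algorithm="hqq",  # assumed: selects the HQQ algorithm
    ),
)
```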