
Conversation

@namgyu-youn
Contributor

@namgyu-youn namgyu-youn commented Dec 8, 2025

Summary:
Add HQQ support for Int4PlainInt32Tensor, which is the W4A-INT quantization API for XPU/NPU devices.

Related Issue/PR: #3013

Test plan: CI

Future Plan:
Although HQQ showed better performance than TinyGEMM for Int4TilePackedTo4dTensor, it is not yet used as the default for XPU/NPU because a performance comparison (TinyGEMM vs. HQQ) is unavailable locally. We can make HQQ the default after that performance test.
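
For reference, a rough usage sketch of how this path might be enabled (not taken from this PR): the int4_packing_format and int4_choose_qparams_algorithm argument names are assumptions about the Int4WeightOnlyConfig surface, and "xpu" simply stands in for the target device.

import torch

from torchao.quantization import Int4WeightOnlyConfig, quantize_

# A small bf16 linear layer on the target device (assumed to be an XPU here).
model = torch.nn.Sequential(
    torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="xpu")
)

# Assumed config knobs: "plain_int32" selects Int4PlainInt32Tensor packing,
# and "hqq" opts into the HQQ qparams algorithm instead of TinyGEMM.
config = Int4WeightOnlyConfig(
    group_size=64,
    int4_packing_format="plain_int32",
    int4_choose_qparams_algorithm="hqq",
)
quantize_(model, config)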

@pytorch-bot

pytorch-bot bot commented Dec 8, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3465

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed label Dec 8, 2025
@namgyu-youn
Contributor Author

@pytorchbot label "ciflow/xpu"

@pytorch-bot

pytorch-bot bot commented Dec 8, 2025

To add these label(s) (ciflow/xpu) to the PR, please first approve the workflows that are awaiting approval (scroll to the bottom of this page).

This helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.

@liangan1
Collaborator

liangan1 commented Dec 9, 2025

@namgyu-youn per my understanding, HQQ uses float zero points, but XPU only supports int zero points now.

@namgyu-youn
Contributor Author

namgyu-youn commented Dec 9, 2025

@namgyu-youn per my understanding, HQQ uses float zero points, but XPU only supports int zero points now.

Thanks, I missed that. The raw_output argument in the HQQ algorithm seems to control the zero-point dtype:

  • If raw_output=True: returns an INT zero point
  • If raw_output=False: returns an FP zero point

I also tested the following code locally for verification. Could you review this PR?

import torch

from torchao.quantization.quant_primitives import (
    _choose_qparams_and_quantize_affine_hqq,
)


def test_hqq_raw_output():
    """Test raw_output=True (INT domain) vs raw_output=False (FLOAT domain)."""
    w = torch.randn(256, 128, dtype=torch.bfloat16, device="cpu")

    # Test INT domain (XPU)
    _, _, zero_int, _ = _choose_qparams_and_quantize_affine_hqq(
        w,
        nbits=4,
        group_size=64,
        axis=1,
        compute_dtype=torch.bfloat16,
        device="cpu",
        verbose=False,
        raw_output=True,
    )

    # Test FLOAT domain (NPU)
    _, _, zero_float, _ = _choose_qparams_and_quantize_affine_hqq(
        w,
        nbits=4,
        group_size=64,
        axis=1,
        compute_dtype=torch.bfloat16,
        device="cpu",
        verbose=False,
        raw_output=False,
    )

    # Validate INT domain (0-15 range, int8 compatible)
    int_valid = (zero_int.min() >= 0) and (zero_int.max() <= 15)
    int8_works = (zero_int.to(torch.int8).min() >= 0) and (
        zero_int.to(torch.int8).max() <= 15
    )

    # Validate FLOAT domain (not integers)
    float_valid = not torch.allclose(
        zero_float.cpu().float(), zero_float.cpu().float().round(), atol=1e-3
    )

    print(
        f"raw_output=True (XPU):  zeros in [0,15]={int_valid}, int8 safe={int8_works}"
    )
    print(f"raw_output=False (NPU): float domain={float_valid}")

    return int_valid and int8_works and float_valid


if __name__ == "__main__":
    success = test_hqq_raw_output()
    exit(0 if success else 1)

And the result is:

> python test.py
raw_output=True (XPU):  zeros in [0,15]=True, int8 safe=True
raw_output=False (NPU): float domain=True

@liangan1 liangan1 added the ciflow/xpu and topic: improvement labels Dec 9, 2025
@xiaowangintel
Collaborator

[quotes @namgyu-youn's raw_output explanation and verification script above]

I don't understand why raw_output=True is related to XPU. Why is this also the case for NPU?

@namgyu-youn
Contributor Author

namgyu-youn commented Dec 9, 2025

I don't understand why raw_output=True is related to XPU. Why is this also the case for NPU?

Let me make it clear: the raw_output value matches what each device's kernel expects, not the device type itself.

  • XPU (torch.ops.aten._weight_int4pack_mm_with_scales_and_zeros): zp dtype=torch.int32 → INT, converted to int8
  • NPU (torch.ops.npu.npu_weight_quant_batchmatmul): zp dtype=w.dtype → FLOAT (fp16/bf16)

That means the XPU kernel requires an INT zero point (zp), while the NPU kernel requires a FLOAT zp, as you mentioned before. Therefore, we can match each kernel's expected dtype through the raw_output argument (see the sketch after this list):

  • raw_output=True: returns an INT-domain zero point
  • raw_output=False: returns a FLOAT-domain zero point
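
To make that concrete, here is a minimal sketch (not code from this PR): hqq_qparams_for_device and its target argument are hypothetical, and only _choose_qparams_and_quantize_affine_hqq comes from torchao.

import torch

from torchao.quantization.quant_primitives import (
    _choose_qparams_and_quantize_affine_hqq,
)


def hqq_qparams_for_device(w: torch.Tensor, target: str):
    # XPU kernel expects an INT zero point; NPU kernel expects a FLOAT zero point.
    use_int_zp = target == "xpu"
    wq, scale, zero, shape = _choose_qparams_and_quantize_affine_hqq(
        w,
        nbits=4,
        group_size=64,
        axis=1,
        compute_dtype=w.dtype,
        device=str(w.device),
        raw_output=use_int_zp,
    )
    if use_int_zp:
        # INT-domain zeros land in [0, 15], so narrowing to int8 is safe.
        zero = zero.to(torch.int8)
    return wq, scale, zero, shape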

@pytorch-bot pytorch-bot bot removed the ciflow/xpu label Dec 9, 2025
@namgyu-youn namgyu-youn changed the title from "Add HQQ support for XPU/NPU int4 quantization" to "Support HQQ for XPU/NPU int4 quantization" Dec 9, 2025