Skip to content

Commit 9550e4c

Browse files
committed
fix zero point dtype for XPU+HQQ
1 parent 6db96c2 commit 9550e4c

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,8 @@ def _from_hp_xpu(
121121
quant_min = 0
122122
quant_max = 15
123123

124-
# 1. use HQQ (Half-Quadratic Quantization) algorithm to compute
125-
# scale and zero_point, then convert to the format that's compatible with XPU kernels
124+
# We support two algorithms for construction: HQQ (mostly better) and TinyGEMM
125+
# Both use the same XPU kernel (_weight_int4pack_mm_with_scales_and_zeros)
126126
if int4_choose_qparams_algorithm == Int4ChooseQParamsAlgorithm.HQQ:
127127
import math
128128

@@ -139,10 +139,9 @@ def _from_hp_xpu(
139139
compute_dtype=compute_dtype,
140140
device=device,
141141
verbose=False,
142-
raw_output=False,
142+
raw_output=True,
143143
)
144144
int_data = int_data.to(target_dtype)
145-
# 2. don't use HQQ, use default choose_qparams_affine algorithm to compute scale and zero_point
146145
else:
147146
assert int4_choose_qparams_algorithm == Int4ChooseQParamsAlgorithm.TINYGEMM, (
148147
f"Unsupported Int4ChooseQParamsAlgorithm: {int4_choose_qparams_algorithm}"
@@ -226,8 +225,8 @@ def _from_hp_npu(
226225
quant_min = -8
227226
quant_max = 7
228227

229-
# 1. use HQQ (Half-Quadratic Quantization) algorithm to compute
230-
# scale and zero_point, then convert to the format that's compatible with XPU kernels
228+
# We support two algorithms for construction: HQQ (mostly better) and TinyGEMM
229+
# Both accept FLOAT zero points for NPU kernel (npu_weight_quant_batchmatmul)
231230
if int4_choose_qparams_algorithm == Int4ChooseQParamsAlgorithm.HQQ:
232231
import math
233232

0 commit comments

Comments
 (0)