Skip to content

Commit 5d0ff1a

Browse files
committed
update to rc2
1 parent fc11d70 commit 5d0ff1a

File tree

8 files changed

+30
-8
lines changed

8 files changed

+30
-8
lines changed

onnxruntime/test/python/quantization/test_qdq.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,28 @@
2828
from onnxruntime.quantization.quant_utils import quantize_nparray
2929

3030

31+
# TODO(titaiwang and justinchuby): What is the recommendation here after onnx deleted this function?
32+
def _unpack_single_4bitx2(x: np.ndarray | np.dtype | float, signed: bool) -> tuple[np.ndarray, np.ndarray]:
33+
def unpack_signed(x):
34+
return np.where((x >> 3) == 0, x, x | 0xF0)
35+
36+
"""Unpack a single byte 4bitx2 to two 4 bit elements
37+
Args:
38+
x: Input data
39+
signed: boolean, whether to interpret as signed int4.
40+
Returns:
41+
A tuple of ndarrays containing int4 elements (sign-extended to int8/uint8)
42+
"""
43+
if not isinstance(x, np.ndarray):
44+
x = np.asarray(x)
45+
x_low = x & 0x0F
46+
x_high = x >> 4
47+
x_low = unpack_signed(x_low) if signed else x_low
48+
x_high = unpack_signed(x_high) if signed else x_high
49+
dtype = np.int8 if signed else np.uint8
50+
return (x_low.astype(dtype), x_high.astype(dtype))
51+
52+
3153
class TestQDQFormat(unittest.TestCase):
3254
def input_feeds(self, n, name2shape, np_float_type=np.float32):
3355
input_data_list = []
@@ -1653,7 +1675,7 @@ def test_int4_qdq_conv(self):
16531675
float_data = weight_data.flatten().tolist()
16541676
for index, float_val in enumerate(float_data):
16551677
expected_int4_val = np.clip(np.float32(float_val / scale_val).round() + zp_val, -8, 7)
1656-
int4_pair = onnx.subbyte.unpack_single_4bitx2(weight_quant_init.raw_data[index >> 1], True)
1678+
int4_pair = _unpack_single_4bitx2(weight_quant_init.raw_data[index >> 1], True)
16571679
int4_val = int4_pair[index & 0x1]
16581680

16591681
self.assertEqual(np.float32(int4_val), expected_int4_val)
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
onnx==1.20.0rc1
1+
onnx==1.20.0rc2
22
pytest
33
onnx-ir

tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/scripts/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,4 @@ wheel
77
protobuf==4.25.8
88
sympy==1.14
99
flatbuffers
10-
onnx==1.20.0rc1; python_version < "3.14"
10+
onnx==1.20.0rc2; python_version < "3.14"

tools/ci_build/github/linux/docker/scripts/lort/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ beartype==0.15.0
33
flatbuffers
44
cerberus
55
h5py
6-
onnx==1.20.0rc1; python_version < "3.14"
6+
onnx==1.20.0rc2; python_version < "3.14"
77
# Python dependencies required for pytorch development
88
astunparse
99
expecttest!=0.2.0

tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,4 @@ flatbuffers
1111
neural-compressor>=2.2.1
1212
triton==3.2.0; python_version < "3.14"
1313
triton==3.5.0; python_version >= "3.14"
14-
onnx==1.20.0rc1; python_version < "3.14"
14+
onnx==1.20.0rc2; python_version < "3.14"

tools/ci_build/github/linux/docker/scripts/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,4 @@ protobuf==6.33.0; python_version >= "3.14"
1313
packaging
1414
onnxscript==0.5.3; python_version < "3.14"
1515
onnx-ir==0.1.10; python_version < "3.14"
16-
onnx==1.20.0rc1; python_version < "3.14"
16+
onnx==1.20.0rc2; python_version < "3.14"

tools/ci_build/github/linux/python/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,4 @@ onnxscript==0.5.3; python_version < "3.14"
1313
onnx-ir==0.1.10; python_version < "3.14"
1414
jinja2
1515
markupsafe
16-
onnx==1.20.0rc1; python_version < "3.14"
16+
onnx==1.20.0rc2; python_version < "3.14"

tools/ci_build/github/windows/python/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,4 @@ markupsafe
1616
semver
1717
packaging
1818
coloredlogs
19-
onnx==1.20.0rc1; python_version < "3.14"
19+
onnx==1.20.0rc2; python_version < "3.14"

0 commit comments

Comments
 (0)