|
28 | 28 | from onnxruntime.quantization.quant_utils import quantize_nparray |
29 | 29 |
|
30 | 30 |
|
| 31 | +# TODO(titaiwang and justinchuby): What is the recommendation here after onnx deleted this function? |
| 32 | +def _unpack_single_4bitx2(x: np.ndarray | np.dtype | float, signed: bool) -> tuple[np.ndarray, np.ndarray]: |
| 33 | + def unpack_signed(x): |
| 34 | + return np.where((x >> 3) == 0, x, x | 0xF0) |
| 35 | + |
| 36 | + """Unpack a single byte 4bitx2 to two 4 bit elements |
| 37 | + Args: |
| 38 | + x: Input data |
| 39 | + signed: boolean, whether to interpret as signed int4. |
| 40 | + Returns: |
| 41 | + A tuple of ndarrays containing int4 elements (sign-extended to int8/uint8) |
| 42 | + """ |
| 43 | + if not isinstance(x, np.ndarray): |
| 44 | + x = np.asarray(x) |
| 45 | + x_low = x & 0x0F |
| 46 | + x_high = x >> 4 |
| 47 | + x_low = unpack_signed(x_low) if signed else x_low |
| 48 | + x_high = unpack_signed(x_high) if signed else x_high |
| 49 | + dtype = np.int8 if signed else np.uint8 |
| 50 | + return (x_low.astype(dtype), x_high.astype(dtype)) |
| 51 | + |
| 52 | + |
31 | 53 | class TestQDQFormat(unittest.TestCase): |
32 | 54 | def input_feeds(self, n, name2shape, np_float_type=np.float32): |
33 | 55 | input_data_list = [] |
@@ -1653,7 +1675,7 @@ def test_int4_qdq_conv(self): |
1653 | 1675 | float_data = weight_data.flatten().tolist() |
1654 | 1676 | for index, float_val in enumerate(float_data): |
1655 | 1677 | expected_int4_val = np.clip(np.float32(float_val / scale_val).round() + zp_val, -8, 7) |
1656 | | - int4_pair = onnx.subbyte.unpack_single_4bitx2(weight_quant_init.raw_data[index >> 1], True) |
| 1678 | + int4_pair = _unpack_single_4bitx2(weight_quant_init.raw_data[index >> 1], True) |
1657 | 1679 | int4_val = int4_pair[index & 0x1] |
1658 | 1680 |
|
1659 | 1681 | self.assertEqual(np.float32(int4_val), expected_int4_val) |
|
0 commit comments