Skip to content

Commit 57c3b8e

Browse files
dsharletgxnnpack-bot
authored andcommitted
Improve reference implementation of quantization
- Add a helper `clamp_float_to_int` that avoids issues with converting large integers to floats and losing information. - Add more information when tests fail. PiperOrigin-RevId: 839514077
1 parent 7c72232 commit 57c3b8e

File tree

4 files changed

+23
-15
lines changed

4 files changed

+23
-15
lines changed

ynnpack/base/arithmetic.h

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,25 @@
1313
#include <cstdint>
1414
#include <limits>
1515

16-
#include "ynnpack/base/base.h"
1716
#include "ynnpack/base/type.h"
1817

1918
namespace ynn {
2019

20+
// Clamp a float to the range of the given integer or quantized integer type.
21+
template <typename Int>
22+
float clamp_float_to_int(float x) {
23+
using Unwrapped = typename unwrap_quantized<Int>::type;
24+
// It's tricky to do this with std::max/std::min, because the min/max values
25+
// might not be exactly representable as floats, and so are ineffective to
26+
// avoid converting to an out of bounds integer. To avoid this problem, we've
27+
// determined a constant that when added to the min/max float values, results
28+
// in the upper bound of the integer range.
29+
constexpr int half_mantissa = sizeof(Unwrapped) * 8 > 23 ? 127 : 0;
30+
x = std::max<float>(x, std::numeric_limits<Unwrapped>::min());
31+
x = std::min<float>(x, std::numeric_limits<Unwrapped>::max() - half_mantissa);
32+
return x;
33+
}
34+
2135
// A cast that:
2236
// - Rounds to nearest integer
2337
// - Replaces NaN with 0
@@ -27,14 +41,7 @@ Result round_float_to_int(float x) {
2741
using Unwrapped = typename unwrap_quantized<Result>::type;
2842
x = std::isnan(x) ? 0.0f : x;
2943
x = std::round(x);
30-
// It's tricky to do this with std::max/std::min, because the min/max values
31-
// might not be exactly representable as floats, and so are ineffective to
32-
// avoid converting to an out of bounds integer. To avoid this problem, we've
33-
// determined a constant that when added to the min/max float values, results
34-
// in the upper bound of the integer range.
35-
constexpr int half_mantissa = sizeof(Unwrapped) * 8 > 23 ? 127 : 0;
36-
x = std::max<float>(x, std::numeric_limits<Unwrapped>::min());
37-
x = std::min<float>(x, std::numeric_limits<Unwrapped>::max() - half_mantissa);
44+
x = clamp_float_to_int<Result>(x);
3845
return static_cast<Unwrapped>(x);
3946
}
4047

ynnpack/kernels/binary/reference.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,8 +152,7 @@ void check_results(const OpInfo& op, const Tensor<quantized<A>>& a,
152152
const float b_i = dequantize(b(i), b_quantization);
153153
float expected = op(a_i, b_i);
154154
expected = fake_quantize(expected, x_quantization);
155-
expected = std::max<float>(expected, type_info<X>::min());
156-
expected = std::min<float>(expected, type_info<X>::max());
155+
expected = clamp_float_to_int<X>(expected);
157156
if (std::isnan(expected)) {
158157
// We don't know how to represent NaN for quantized types.
159158
} else {

ynnpack/kernels/unary/reference.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -473,15 +473,15 @@ void check_results(const unary_op_info& op, Tensor<A> a, Tensor<X> x,
473473
const float input_i = dequantize(a(i), a_quantization);
474474
float expected = op(input_i);
475475
expected = fake_quantize(expected, x_quantization);
476-
expected = std::max<float>(expected, type_info<X>::min());
477-
expected = std::min<float>(expected, type_info<X>::max());
476+
expected = clamp_float_to_int<X>(expected);
478477
if (std::isnan(expected)) {
479478
// This is expected to overflow.
480479
} else {
481480
ASSERT_NEAR(expected, x(i), op.tolerance(expected, type_of<X>()))
482481
<< "i = " << index_to_string(i) << ", a(i) = " << input_i << " ("
483482
<< static_cast<float>(a(i)) << ")"
484-
<< ", x(i) = " << static_cast<int32_t>(x(i));
483+
<< ", x(i) = " << static_cast<int32_t>(x(i)) << " ("
484+
<< dequantize(x(i), x_quantization) << ")" << std::endl;
485485
}
486486
} else {
487487
const float input_i = dequantize(a(i), a_quantization);

ynnpack/xnnpack/dynamic_quantization_test.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,9 @@ void TestImpl(T, size_t rank) {
8383
broadcast_extent_1(zero_point);
8484
for (const auto& i : EnumerateIndices(shape)) {
8585
ASSERT_NEAR(quantize<int8_t>(input(i), 1.0f / scale(i), zero_point(i)),
86-
output(i), 1);
86+
output(i), 1)
87+
<< "input=" << input(i) << ", scale=" << scale(i)
88+
<< ", zero_point=" << zero_point(i);
8789
}
8890
}
8991
}

0 commit comments

Comments
 (0)