Commit 4870d45

Authored by rivkastroh and Copilot
Add int8 support to ConvInteger (#26585)
### Description

This change extends the `ConvInteger` implementation to match the [ONNX operator spec](https://onnx.ai/onnx/operators/onnx__ConvInteger.html), which allows both `int8` and `uint8` for the input tensors:

- The ONNX `ConvInteger` schema defines:
  - `T1`: `tensor(int8)` or `tensor(uint8)`
  - `T2`: `tensor(int8)` or `tensor(uint8)`
  - `T3`: `tensor(int32)`
- Previously, only the `uint8` × `uint8` combination was supported.
- This PR adds support for all 8-bit combinations:
  - `uint8` × `uint8` (existing behavior)
  - `uint8` × `int8`
  - `int8` × `uint8`
  - `int8` × `int8`

### Motivation and Context

Fixes #24183
Fixes #15888
Fixes #12558
Fixes #3130
Fixes #12362

The ONNX `ConvInteger` operator schema allows both `int8` and `uint8` element types for its inputs, but the previous implementation only supported `uint8` × `uint8`. As a result, valid ONNX models that use `ConvInteger` with `int8` tensors could not be executed. This PR closes that gap by:

- Aligning the implementation with the official `ConvInteger` type constraints.
- Enabling models that use `int8` (or mixed `int8`/`uint8`) for `X` and `W` to run without operator rewrites or additional custom kernels.
- Keeping the existing `uint8` behavior unchanged, so the change is backwards compatible for current users.

### Implementation details

1. **Templated core implementation (`ComputeInner`).** The core logic of `ConvInteger::Compute` is moved into a templated helper:

   ```cpp
   class ConvInteger : public OpKernel {
    public:
     ...
    private:
     template <typename XT, typename WT>
     Status ComputeInner(OpKernelContext* context) const;
   };
   ```

   `XT` is the element type of `X` (`uint8_t` or `int8_t`); `WT` is the element type of `W` (`uint8_t` or `int8_t`).

2. **Zero-point handling.** Zero points are still treated as per-tensor scalar values, with the same validation as before. The values are read via `DataRaw()` and stored as 8-bit scalars, preserving the previous behavior. Interpretation of these raw bytes as signed or unsigned is delegated to the GEMM implementation via explicit signedness flags (see below).

3. **`Im2col` templated on `XT`.** The `Im2col` call now uses the runtime input type `XT`.

4. **Quantized GEMM with signedness flags.**

   ```cpp
   gemm_shape.AIsSigned = W->IsDataType<int8_t>();
   gemm_shape.BIsSigned = X->IsDataType<int8_t>();
   ```

   `AIsSigned` and `BIsSigned` are derived from the runtime types of `W` and `X`. Data for `A` and `B` is passed as raw bytes; the GEMM implementation uses the signedness flags to interpret them correctly, in a manner similar to the `MatMulInteger` implementation.

5. **Runtime dispatch in `Compute()`.** The public `Compute` method becomes a thin dispatcher that selects the appropriate `ComputeInner<XT, WT>` instantiation based on the actual input types.

In addition, a small set of unit tests is added on top of the existing `ConvInteger` tests to cover the new type combinations, including cases where the first input tensor contains negative values (for the `int8` × `int8` path).

---------

Co-authored-by: Copilot <[email protected]>
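The runtime dispatch described in item 5 can be sketched roughly as follows. This is an illustrative stand-alone sketch, not the actual ORT code: `Compute` here takes plain bools instead of querying tensor types, and `ComputeInner` merely reports which instantiation was selected.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <type_traits>

// Sketch only: stand-in for the templated helper. The real ComputeInner
// would run the im2col + quantized GEMM pipeline for element types XT/WT;
// here it just reports which instantiation was chosen.
template <typename XT, typename WT>
std::string ComputeInner() {
  return std::string(std::is_signed<XT>::value ? "int8" : "uint8") + " x " +
         (std::is_signed<WT>::value ? "int8" : "uint8");
}

// Thin dispatcher: pick the ComputeInner<XT, WT> instantiation that matches
// the runtime element types of X and W.
std::string Compute(bool x_is_signed, bool w_is_signed) {
  if (x_is_signed) {
    return w_is_signed ? ComputeInner<int8_t, int8_t>()
                       : ComputeInner<int8_t, uint8_t>();
  }
  return w_is_signed ? ComputeInner<uint8_t, int8_t>()
                     : ComputeInner<uint8_t, uint8_t>();
}
```

The dispatch is a fixed 2×2 table, so all four kernel instantiations are compiled in and the branch cost is paid once per `Compute` call, not per element.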
1 parent 977efe4 commit 4870d45

File tree: 4 files changed, +669 −38 lines

docs/OperatorKernels.md

Lines changed: 1 addition & 1 deletion

```diff
@@ -88,7 +88,7 @@ Do not modify directly.*
 |Conv|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *out* Y:**T**|22+|**T** = tensor(float)|
 |||[11, 21]|**T** = tensor(float)|
 |||[1, 10]|**T** = tensor(float)|
-|ConvInteger|*in* x:**T1**<br> *in* w:**T2**<br> *in* x_zero_point:**T1**<br> *in* w_zero_point:**T2**<br> *out* y:**T3**|10+|**T1** = tensor(uint8)<br/> **T2** = tensor(uint8)<br/> **T3** = tensor(int32)|
+|ConvInteger|*in* x:**T1**<br> *in* w:**T2**<br> *in* x_zero_point:**T1**<br> *in* w_zero_point:**T2**<br> *out* y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(int32)|
 |ConvTranspose|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *out* Y:**T**|22+|**T** = tensor(float)|
 |||[11, 21]|**T** = tensor(float)|
 |||[1, 10]|**T** = tensor(float)|
```

onnxruntime/core/providers/cpu/quantization/conv_integer.cc

Lines changed: 78 additions & 37 deletions

```diff
@@ -28,8 +28,10 @@ ONNX_OPERATOR_KERNEL_EX(
     10,
     kCpuExecutionProvider,
     KernelDefBuilder()
-        .TypeConstraint("T1", DataTypeImpl::GetTensorType<uint8_t>())
-        .TypeConstraint("T2", DataTypeImpl::GetTensorType<uint8_t>())
+        .TypeConstraint("T1", {DataTypeImpl::GetTensorType<uint8_t>(),
+                               DataTypeImpl::GetTensorType<int8_t>()})
+        .TypeConstraint("T2", {DataTypeImpl::GetTensorType<uint8_t>(),
+                               DataTypeImpl::GetTensorType<int8_t>()})
         .TypeConstraint("T3", DataTypeImpl::GetTensorType<int32_t>()),
     ConvInteger);
 
@@ -43,12 +45,12 @@ Status ConvInteger::Compute(OpKernelContext* context) const {
   if (num_inputs >= 3 && input_defs[2]->Exists()) {
     const auto* X_Zero_Point = context->Input<Tensor>(2);
     ORT_ENFORCE(IsScalarOr1ElementVector(X_Zero_Point), "Must be a scalar or 1D tensor or size 1.");
-    input_offset = *(X_Zero_Point->Data<uint8_t>());
+    input_offset = *static_cast<const uint8_t*>(X_Zero_Point->DataRaw());
   }
   if (num_inputs >= 4 && input_defs[3]->Exists()) {
     const auto* W_Zero_Point = context->Input<Tensor>(3);
     ORT_ENFORCE(IsScalarOr1ElementVector(W_Zero_Point), "Non per-tensor quantization is not supported now.");
-    filter_offset = *(W_Zero_Point->Data<uint8_t>());
+    filter_offset = *static_cast<const uint8_t*>(W_Zero_Point->DataRaw());
   }
 
   const int64_t N = X->Shape()[0];
@@ -110,58 +112,97 @@ Status ConvInteger::Compute(OpKernelContext* context) const {
 
   concurrency::ThreadPool* thread_pool = context->GetOperatorThreadPool();
 
-  const auto* Xdata = X->Data<uint8_t>();
-  const auto* Wdata = W->Data<uint8_t>();
+  const auto* Xdata = static_cast<const uint8_t*>(X->DataRaw());
+  const auto* Wdata = static_cast<const uint8_t*>(W->DataRaw());
+  bool X_is_signed = X->IsDataType<int8_t>();
   auto* Ydata = Y->MutableData<int32_t>();
 
   for (int image_id = 0; image_id < N; ++image_id) {
     for (int group_id = 0; group_id < conv_attrs_.group; ++group_id) {
       if (col_buffer_data != nullptr) {
         if (kernel_rank == 2) {
-          math::Im2col<uint8_t, StorageOrder::NCHW>()(
-              Xdata,
-              C / conv_attrs_.group,
-              input_shape[0],
-              input_shape[1],
-              kernel_shape[0],
-              kernel_shape[1],
-              dilations[0],
-              dilations[1],
-              pads[0],
-              pads[1],
-              pads[2],
-              pads[3],
-              strides[0],
-              strides[1],
-              col_buffer_data,
-              input_offset);
+          if (X_is_signed) {
+            math::Im2col<int8_t, StorageOrder::NCHW>()(
+                reinterpret_cast<const int8_t*>(Xdata),
+                C / conv_attrs_.group,
+                input_shape[0],
+                input_shape[1],
+                kernel_shape[0],
+                kernel_shape[1],
+                dilations[0],
+                dilations[1],
+                pads[0],
+                pads[1],
+                pads[2],
+                pads[3],
+                strides[0],
+                strides[1],
+                reinterpret_cast<int8_t*>(col_buffer_data),
+                static_cast<int8_t>(input_offset));
+          } else {
+            math::Im2col<uint8_t, StorageOrder::NCHW>()(
+                Xdata,
+                C / conv_attrs_.group,
+                input_shape[0],
+                input_shape[1],
+                kernel_shape[0],
+                kernel_shape[1],
+                dilations[0],
+                dilations[1],
+                pads[0],
+                pads[1],
+                pads[2],
+                pads[3],
+                strides[0],
+                strides[1],
+                col_buffer_data,
+                input_offset);
+          }
         } else {
-          math::Im2col<uint8_t, StorageOrder::NCHW>()(
-              Xdata,
-              input_shape.GetDims().data(),
-              output_shape.GetDims().data(),
-              kernel_dim,
-              kernel_shape.data(),
-              strides.data(),
-              dilations.data(),
-              pads.data(),
-              static_cast<int>(kernel_rank),
-              col_buffer_data,
-              false,
-              input_offset);
+          if (X_is_signed) {
+            math::Im2col<int8_t, StorageOrder::NCHW>()(
+                reinterpret_cast<const int8_t*>(Xdata),
+                input_shape.GetDims().data(),
+                output_shape.GetDims().data(),
+                kernel_dim,
+                kernel_shape.data(),
+                strides.data(),
+                dilations.data(),
+                pads.data(),
+                static_cast<int>(kernel_rank),
+                reinterpret_cast<int8_t*>(col_buffer_data),
+                false,
+                static_cast<int8_t>(input_offset));
+          } else {
+            math::Im2col<uint8_t, StorageOrder::NCHW>()(
+                Xdata,
+                input_shape.GetDims().data(),
+                output_shape.GetDims().data(),
+                kernel_dim,
+                kernel_shape.data(),
+                strides.data(),
+                dilations.data(),
+                pads.data(),
+                static_cast<int>(kernel_rank),
+                col_buffer_data,
+                false,
+                input_offset);
+          }
         }
       }
 
       MLAS_GEMM_QUANT_SHAPE_PARAMS gemm_shape;
       gemm_shape.M = static_cast<size_t>(M / conv_attrs_.group);
       gemm_shape.N = static_cast<size_t>(output_image_size);
       gemm_shape.K = static_cast<size_t>(kernel_dim);
+      gemm_shape.AIsSigned = W->IsDataType<int8_t>();
+      gemm_shape.BIsSigned = X_is_signed;
 
       MLAS_GEMM_QUANT_DATA_PARAMS gemm_params;
       gemm_params.A = Wdata + group_id * W_offset;
       gemm_params.lda = static_cast<size_t>(kernel_dim);
       gemm_params.ZeroPointA = filter_offset;
-      gemm_params.B = (col_buffer_data == nullptr) ? Xdata : col_buffer_data,
+      gemm_params.B = (col_buffer_data == nullptr) ? Xdata : col_buffer_data;
       gemm_params.ldb = static_cast<size_t>(output_image_size);
       gemm_params.ZeroPointB = &input_offset;
       gemm_params.C = Ydata;
```
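For reference, the per-element accumulation that the quantized GEMM above performs can be sketched as follows: per the ONNX `ConvInteger` spec, each 8-bit value is widened to 32 bits after subtracting its per-tensor zero point, so one formula covers every signed/unsigned combination. `DotInteger` is an illustrative name of ours, not an ORT or MLAS function.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Reference dot product for one ConvInteger output element: widen each
// 8-bit value to int32, subtract the per-tensor zero point, then
// multiply-accumulate. Works for any mix of uint8_t/int8_t element types,
// which is why the kernel only needs signedness flags, not separate math.
template <typename XT, typename WT>
int32_t DotInteger(const std::vector<XT>& x, const std::vector<WT>& w,
                   XT x_zero_point, WT w_zero_point) {
  int32_t acc = 0;
  for (size_t i = 0; i < x.size(); ++i) {
    acc += (static_cast<int32_t>(x[i]) - static_cast<int32_t>(x_zero_point)) *
           (static_cast<int32_t>(w[i]) - static_cast<int32_t>(w_zero_point));
  }
  return acc;
}
```

Because the subtraction happens after widening, the raw bytes can be passed to the GEMM unchanged; only the interpretation (via `AIsSigned`/`BIsSigned`) differs between the `int8` and `uint8` paths.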

onnxruntime/core/util/math_cpu.cc

Lines changed: 1 addition & 0 deletions

```diff
@@ -527,6 +527,7 @@ void Im2col<T, StorageOrder::NCHW>::operator()(
 
 template struct Im2col<float, StorageOrder::NCHW>;
 template struct Im2col<uint8_t, StorageOrder::NCHW>;
+template struct Im2col<int8_t, StorageOrder::NCHW>;
 
 template <typename T>
 void Im2col<T, StorageOrder::NHWC>::operator()(
```
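The new `template struct Im2col<int8_t, StorageOrder::NCHW>;` line is needed because `Im2col`'s member definitions live in math_cpu.cc rather than a header, so every element type used by callers must be explicitly instantiated in that translation unit. A minimal sketch of the pattern, using a hypothetical `Square` template (not an ORT type):

```cpp
#include <cassert>
#include <cstdint>

// In ORT this declaration would sit in a header, with the definition and the
// explicit instantiations below in a .cc file (as Im2col's are in math_cpu.cc).
template <typename T>
struct Square {
  T operator()(T v) const;
};

// Out-of-line definition: without explicit instantiation, other translation
// units that use Square<T> would fail to link.
template <typename T>
T Square<T>::operator()(T v) const { return v * v; }

// Explicit instantiations for every supported element type -- analogous to
// the uint8_t and (newly added) int8_t Im2col instantiations in the diff.
template struct Square<uint8_t>;
template struct Square<int8_t>;
```

Forgetting the instantiation is a link-time error, not a compile-time one, which is why this one-line change accompanies the new `int8` path in conv_integer.cc.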

0 commit comments