Skip to content

Commit 76c4082

Browse files
committed
Refactor ConvInteger registration and fix lint warnings
1 parent 538ecc9 commit 76c4082

File tree

2 files changed

+158
-155
lines changed

2 files changed

+158
-155
lines changed

onnxruntime/core/providers/cpu/quantization/conv_integer.cc

Lines changed: 146 additions & 143 deletions
Original file line numberDiff line numberDiff line change
@@ -23,165 +23,168 @@ class ConvInteger : public OpKernel {
2323

2424
private:
2525
template <typename XT, typename WT>
26-
Status ComputeInner(OpKernelContext* context) const {
27-
const auto input_defs = Node().InputDefs();
28-
size_t num_inputs = input_defs.size();
29-
const auto* X = context->Input<Tensor>(0);
30-
const auto* W = context->Input<Tensor>(1);
31-
uint8_t input_offset = 0;
32-
uint8_t filter_offset = 0;
33-
if (num_inputs >= 3 && input_defs[2]->Exists()) {
34-
const auto* X_Zero_Point = context->Input<Tensor>(2);
35-
ORT_ENFORCE(IsScalarOr1ElementVector(X_Zero_Point), "Must be a scalar or 1D tensor or size 1.");
36-
input_offset = *static_cast<const uint8_t*>(X_Zero_Point->DataRaw());
37-
}
38-
if (num_inputs >= 4 && input_defs[3]->Exists()) {
39-
const auto* W_Zero_Point = context->Input<Tensor>(3);
40-
ORT_ENFORCE(IsScalarOr1ElementVector(W_Zero_Point), "Non per-tensor quantization is not supported now.");
41-
filter_offset = *static_cast<const uint8_t*>(W_Zero_Point->DataRaw());
42-
}
26+
Status ComputeInner(OpKernelContext* context) const;
27+
};
4328

44-
const int64_t N = X->Shape()[0];
45-
const int64_t C = X->Shape()[1];
46-
const int64_t M = W->Shape()[0];
47-
ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(X, W));
29+
ONNX_OPERATOR_KERNEL_EX(
30+
ConvInteger,
31+
kOnnxDomain,
32+
10,
33+
kCpuExecutionProvider,
34+
KernelDefBuilder()
35+
.TypeConstraint("T1", {DataTypeImpl::GetTensorType<uint8_t>(),
36+
DataTypeImpl::GetTensorType<int8_t>()})
37+
.TypeConstraint("T2", {DataTypeImpl::GetTensorType<uint8_t>(),
38+
DataTypeImpl::GetTensorType<int8_t>()})
39+
.TypeConstraint("T3", DataTypeImpl::GetTensorType<int32_t>()),
40+
ConvInteger);
4841

49-
TensorShapeVector kernel_shape;
50-
ORT_RETURN_IF_ERROR(conv_attrs_.ComputeKernelShape(W->Shape(), kernel_shape));
42+
template <typename XT, typename WT>
43+
Status ConvInteger::ComputeInner(OpKernelContext* context) const {
44+
const auto input_defs = Node().InputDefs();
45+
size_t num_inputs = input_defs.size();
46+
const auto* X = context->Input<Tensor>(0);
47+
const auto* W = context->Input<Tensor>(1);
48+
uint8_t input_offset = 0;
49+
uint8_t filter_offset = 0;
50+
if (num_inputs >= 3 && input_defs[2]->Exists()) {
51+
const auto* X_Zero_Point = context->Input<Tensor>(2);
52+
ORT_ENFORCE(IsScalarOr1ElementVector(X_Zero_Point), "Must be a scalar or 1D tensor or size 1.");
53+
input_offset = *static_cast<const uint8_t*>(X_Zero_Point->DataRaw());
54+
}
55+
if (num_inputs >= 4 && input_defs[3]->Exists()) {
56+
const auto* W_Zero_Point = context->Input<Tensor>(3);
57+
ORT_ENFORCE(IsScalarOr1ElementVector(W_Zero_Point), "Non per-tensor quantization is not supported now.");
58+
filter_offset = *static_cast<const uint8_t*>(W_Zero_Point->DataRaw());
59+
}
5160

52-
ConvPadVector pads(conv_attrs_.pads);
53-
if (pads.empty()) {
54-
pads.resize(kernel_shape.size() * 2, 0);
55-
}
56-
TensorShapeVector dilations(conv_attrs_.dilations);
57-
if (dilations.empty()) {
58-
dilations.resize(kernel_shape.size(), 1);
59-
}
60-
TensorShapeVector strides(conv_attrs_.strides);
61-
if (strides.empty()) {
62-
strides.resize(kernel_shape.size(), 1);
63-
}
61+
const int64_t N = X->Shape()[0];
62+
const int64_t C = X->Shape()[1];
63+
const int64_t M = W->Shape()[0];
64+
ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(X, W));
6465

65-
TensorShapeVector Y_dims({N, M});
66-
TensorShape input_shape = X->Shape().Slice(2);
67-
ORT_RETURN_IF_ERROR(conv_attrs_.InferPadsAndOutputShape(input_shape, kernel_shape, strides, dilations, pads, Y_dims));
68-
Tensor* Y = context->Output(0, TensorShape(Y_dims));
69-
TensorShape output_shape = Y->Shape().Slice(2);
66+
TensorShapeVector kernel_shape;
67+
ORT_RETURN_IF_ERROR(conv_attrs_.ComputeKernelShape(W->Shape(), kernel_shape));
7068

71-
// Bail out early if one of the dimensions is zero.
72-
if (Y->Shape().Size() == 0) {
73-
return Status::OK();
74-
}
69+
ConvPadVector pads(conv_attrs_.pads);
70+
if (pads.empty()) {
71+
pads.resize(kernel_shape.size() * 2, 0);
72+
}
73+
TensorShapeVector dilations(conv_attrs_.dilations);
74+
if (dilations.empty()) {
75+
dilations.resize(kernel_shape.size(), 1);
76+
}
77+
TensorShapeVector strides(conv_attrs_.strides);
78+
if (strides.empty()) {
79+
strides.resize(kernel_shape.size(), 1);
80+
}
7581

76-
const int64_t input_image_size = input_shape.Size();
77-
const int64_t output_image_size = output_shape.Size();
78-
const int64_t kernel_size = TensorShape(kernel_shape).Size();
79-
const int64_t X_offset = C / conv_attrs_.group * input_image_size;
80-
const int64_t Y_offset = Y->Shape().Size() / Y->Shape()[0] / conv_attrs_.group;
81-
const int64_t W_offset = W->Shape().Size() / conv_attrs_.group;
82-
const int64_t kernel_dim = C / conv_attrs_.group * kernel_size;
83-
const int64_t col_buffer_size = kernel_dim * output_image_size;
82+
TensorShapeVector Y_dims({N, M});
83+
TensorShape input_shape = X->Shape().Slice(2);
84+
ORT_RETURN_IF_ERROR(conv_attrs_.InferPadsAndOutputShape(input_shape, kernel_shape, strides, dilations, pads, Y_dims));
85+
Tensor* Y = context->Output(0, TensorShape(Y_dims));
86+
TensorShape output_shape = Y->Shape().Slice(2);
8487

85-
const size_t kernel_rank = kernel_shape.size();
88+
// Bail out early if one of the dimensions is zero.
89+
if (Y->Shape().Size() == 0) {
90+
return Status::OK();
91+
}
8692

87-
BufferUniquePtr col_buffer;
93+
const int64_t input_image_size = input_shape.Size();
94+
const int64_t output_image_size = output_shape.Size();
95+
const int64_t kernel_size = TensorShape(kernel_shape).Size();
96+
const int64_t X_offset = C / conv_attrs_.group * input_image_size;
97+
const int64_t Y_offset = Y->Shape().Size() / Y->Shape()[0] / conv_attrs_.group;
98+
const int64_t W_offset = W->Shape().Size() / conv_attrs_.group;
99+
const int64_t kernel_dim = C / conv_attrs_.group * kernel_size;
100+
const int64_t col_buffer_size = kernel_dim * output_image_size;
88101

89-
// Pointwise convolutions can use the original input tensor in place,
90-
// otherwise a temporary buffer is required for the im2col transform.
91-
if (kernel_size != 1 || !conv_attrs_.HasStridesOneAndNoPadding()) {
92-
AllocatorPtr alloc;
93-
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc));
102+
const size_t kernel_rank = kernel_shape.size();
94103

95-
auto* col_data = alloc->Alloc(SafeInt<size_t>(sizeof(uint8_t)) * col_buffer_size);
96-
col_buffer = BufferUniquePtr(col_data, BufferDeleter(std::move(alloc)));
97-
}
104+
BufferUniquePtr col_buffer;
98105

99-
auto* col_buffer_data = static_cast<uint8_t*>(col_buffer.get());
100-
101-
concurrency::ThreadPool* thread_pool = context->GetOperatorThreadPool();
102-
103-
const auto* Xdata = static_cast<const uint8_t*>(X->DataRaw());
104-
const auto* Wdata = static_cast<const uint8_t*>(W->DataRaw());
105-
auto* Ydata = Y->MutableData<int32_t>();
106-
107-
for (int image_id = 0; image_id < N; ++image_id) {
108-
for (int group_id = 0; group_id < conv_attrs_.group; ++group_id) {
109-
if (col_buffer_data != nullptr) {
110-
if (kernel_rank == 2) {
111-
math::Im2col<XT, StorageOrder::NCHW>()(
112-
reinterpret_cast<const XT*>(Xdata),
113-
C / conv_attrs_.group,
114-
input_shape[0],
115-
input_shape[1],
116-
kernel_shape[0],
117-
kernel_shape[1],
118-
dilations[0],
119-
dilations[1],
120-
pads[0],
121-
pads[1],
122-
pads[2],
123-
pads[3],
124-
strides[0],
125-
strides[1],
126-
reinterpret_cast<XT*>(col_buffer_data),
127-
static_cast<XT>(input_offset));
128-
} else {
129-
math::Im2col<XT, StorageOrder::NCHW>()(
130-
reinterpret_cast<const XT*>(Xdata),
131-
input_shape.GetDims().data(),
132-
output_shape.GetDims().data(),
133-
kernel_dim,
134-
kernel_shape.data(),
135-
strides.data(),
136-
dilations.data(),
137-
pads.data(),
138-
static_cast<int>(kernel_rank),
139-
reinterpret_cast<XT*>(col_buffer_data),
140-
false,
141-
static_cast<XT>(input_offset));
142-
}
143-
}
106+
// Pointwise convolutions can use the original input tensor in place,
107+
// otherwise a temporary buffer is required for the im2col transform.
108+
if (kernel_size != 1 || !conv_attrs_.HasStridesOneAndNoPadding()) {
109+
AllocatorPtr alloc;
110+
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc));
111+
112+
auto* col_data = alloc->Alloc(SafeInt<size_t>(sizeof(uint8_t)) * col_buffer_size);
113+
col_buffer = BufferUniquePtr(col_data, BufferDeleter(std::move(alloc)));
114+
}
144115

145-
MLAS_GEMM_QUANT_SHAPE_PARAMS gemm_shape;
146-
gemm_shape.M = static_cast<size_t>(M / conv_attrs_.group);
147-
gemm_shape.N = static_cast<size_t>(output_image_size);
148-
gemm_shape.K = static_cast<size_t>(kernel_dim);
149-
gemm_shape.AIsSigned = W->IsDataType<int8_t>();
150-
gemm_shape.BIsSigned = X->IsDataType<int8_t>();
151-
152-
MLAS_GEMM_QUANT_DATA_PARAMS gemm_params;
153-
gemm_params.A = Wdata + group_id * W_offset;
154-
gemm_params.lda = static_cast<size_t>(kernel_dim);
155-
gemm_params.ZeroPointA = filter_offset;
156-
gemm_params.B = (col_buffer_data == nullptr) ? Xdata : col_buffer_data;
157-
gemm_params.ldb = static_cast<size_t>(output_image_size);
158-
gemm_params.ZeroPointB = &input_offset;
159-
gemm_params.C = Ydata;
160-
gemm_params.ldc = static_cast<size_t>(output_image_size);
161-
162-
MlasGemm(gemm_shape, gemm_params, thread_pool);
163-
164-
Xdata = reinterpret_cast<const uint8_t*>(X_offset + reinterpret_cast<const XT*>(Xdata));
165-
Ydata += Y_offset;
116+
auto* col_buffer_data = static_cast<uint8_t*>(col_buffer.get());
117+
118+
concurrency::ThreadPool* thread_pool = context->GetOperatorThreadPool();
119+
120+
const auto* Xdata = static_cast<const uint8_t*>(X->DataRaw());
121+
const auto* Wdata = static_cast<const uint8_t*>(W->DataRaw());
122+
auto* Ydata = Y->MutableData<int32_t>();
123+
124+
for (int image_id = 0; image_id < N; ++image_id) {
125+
for (int group_id = 0; group_id < conv_attrs_.group; ++group_id) {
126+
if (col_buffer_data != nullptr) {
127+
if (kernel_rank == 2) {
128+
math::Im2col<XT, StorageOrder::NCHW>()(
129+
reinterpret_cast<const XT*>(Xdata),
130+
C / conv_attrs_.group,
131+
input_shape[0],
132+
input_shape[1],
133+
kernel_shape[0],
134+
kernel_shape[1],
135+
dilations[0],
136+
dilations[1],
137+
pads[0],
138+
pads[1],
139+
pads[2],
140+
pads[3],
141+
strides[0],
142+
strides[1],
143+
reinterpret_cast<XT*>(col_buffer_data),
144+
static_cast<XT>(input_offset));
145+
} else {
146+
math::Im2col<XT, StorageOrder::NCHW>()(
147+
reinterpret_cast<const XT*>(Xdata),
148+
input_shape.GetDims().data(),
149+
output_shape.GetDims().data(),
150+
kernel_dim,
151+
kernel_shape.data(),
152+
strides.data(),
153+
dilations.data(),
154+
pads.data(),
155+
static_cast<int>(kernel_rank),
156+
reinterpret_cast<XT*>(col_buffer_data),
157+
false,
158+
static_cast<XT>(input_offset));
159+
}
166160
}
167-
}
168161

169-
return Status::OK();
162+
MLAS_GEMM_QUANT_SHAPE_PARAMS gemm_shape;
163+
gemm_shape.M = static_cast<size_t>(M / conv_attrs_.group);
164+
gemm_shape.N = static_cast<size_t>(output_image_size);
165+
gemm_shape.K = static_cast<size_t>(kernel_dim);
166+
gemm_shape.AIsSigned = W->IsDataType<int8_t>();
167+
gemm_shape.BIsSigned = X->IsDataType<int8_t>();
168+
169+
MLAS_GEMM_QUANT_DATA_PARAMS gemm_params;
170+
gemm_params.A = Wdata + group_id * W_offset;
171+
gemm_params.lda = static_cast<size_t>(kernel_dim);
172+
gemm_params.ZeroPointA = filter_offset;
173+
gemm_params.B = (col_buffer_data == nullptr) ? Xdata : col_buffer_data;
174+
gemm_params.ldb = static_cast<size_t>(output_image_size);
175+
gemm_params.ZeroPointB = &input_offset;
176+
gemm_params.C = Ydata;
177+
gemm_params.ldc = static_cast<size_t>(output_image_size);
178+
179+
MlasGemm(gemm_shape, gemm_params, thread_pool);
180+
181+
Xdata = reinterpret_cast<const uint8_t*>(X_offset + reinterpret_cast<const XT*>(Xdata));
182+
Ydata += Y_offset;
183+
}
170184
}
171-
};
172185

173-
ONNX_OPERATOR_KERNEL_EX(
174-
ConvInteger,
175-
kOnnxDomain,
176-
10,
177-
kCpuExecutionProvider,
178-
KernelDefBuilder()
179-
.TypeConstraint("T1", {DataTypeImpl::GetTensorType<uint8_t>(),
180-
DataTypeImpl::GetTensorType<int8_t>()})
181-
.TypeConstraint("T2", {DataTypeImpl::GetTensorType<uint8_t>(),
182-
DataTypeImpl::GetTensorType<int8_t>()})
183-
.TypeConstraint("T3", DataTypeImpl::GetTensorType<int32_t>()),
184-
ConvInteger);
186+
return Status::OK();
187+
}
185188

186189
Status ConvInteger::Compute(OpKernelContext* context) const {
187190
const auto* X = context->Input<Tensor>(0);

onnxruntime/test/providers/cpu/nn/conv_integer_test.cc

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -678,9 +678,9 @@ TEST(ConvIntegerTest, WithoutPadding_2D_s8u8) {
678678

679679
std::vector<int64_t> x_dims{1, 1, 3, 3};
680680
test.AddInput<int8_t>("x", x_dims,
681-
{-1, 2, -3,
682-
4, -5, 6,
683-
-7, 8, -9});
681+
{-1, 2, -3,
682+
4, -5, 6,
683+
-7, 8, -9});
684684

685685
std::vector<int64_t> w_dims{1, 1, 2, 2};
686686
test.AddInput<uint8_t>("w", w_dims,
@@ -692,8 +692,8 @@ TEST(ConvIntegerTest, WithoutPadding_2D_s8u8) {
692692

693693
std::vector<int64_t> y_dims{1, 1, 2, 2};
694694
test.AddOutput<int32_t>("y", y_dims,
695-
{-5, 5,
696-
5, -5});
695+
{-5, 5,
696+
5, -5});
697697

698698
test.Run();
699699
}
@@ -703,9 +703,9 @@ TEST(ConvIntegerTest, WithPadding_2D_s8u8) {
703703

704704
std::vector<int64_t> x_dims{1, 1, 3, 3};
705705
test.AddInput<int8_t>("x", x_dims,
706-
{-1, 2, -3,
707-
4, -5, 6,
708-
-7, 8, -9});
706+
{-1, 2, -3,
707+
4, -5, 6,
708+
-7, 8, -9});
709709

710710
std::vector<int64_t> w_dims{1, 1, 2, 2};
711711
test.AddInput<uint8_t>("w", w_dims,
@@ -718,10 +718,10 @@ TEST(ConvIntegerTest, WithPadding_2D_s8u8) {
718718

719719
std::vector<int64_t> y_dims{1, 1, 4, 4};
720720
test.AddOutput<int32_t>("y", y_dims,
721-
{ -4, 5, -6, -9,
722-
14, -5, 5, 15,
723-
-20, 5, -5, -21,
724-
-14, 9, -10, -9});
721+
{-4, 5, -6, -9,
722+
14, -5, 5, 15,
723+
-20, 5, -5, -21,
724+
-14, 9, -10, -9});
725725

726726
test.Run();
727727
}

0 commit comments

Comments
 (0)