@@ -23,165 +23,168 @@ class ConvInteger : public OpKernel {

  private:
   template <typename XT, typename WT>
-  Status ComputeInner(OpKernelContext* context) const {
-    const auto input_defs = Node().InputDefs();
-    size_t num_inputs = input_defs.size();
-    const auto* X = context->Input<Tensor>(0);
-    const auto* W = context->Input<Tensor>(1);
-    uint8_t input_offset = 0;
-    uint8_t filter_offset = 0;
-    if (num_inputs >= 3 && input_defs[2]->Exists()) {
-      const auto* X_Zero_Point = context->Input<Tensor>(2);
-      ORT_ENFORCE(IsScalarOr1ElementVector(X_Zero_Point), "Must be a scalar or 1D tensor of size 1.");
-      input_offset = *static_cast<const uint8_t*>(X_Zero_Point->DataRaw());
-    }
-    if (num_inputs >= 4 && input_defs[3]->Exists()) {
-      const auto* W_Zero_Point = context->Input<Tensor>(3);
-      ORT_ENFORCE(IsScalarOr1ElementVector(W_Zero_Point), "Non per-tensor quantization is not supported now.");
-      filter_offset = *static_cast<const uint8_t*>(W_Zero_Point->DataRaw());
-    }
+  Status ComputeInner(OpKernelContext* context) const;
+};

-    const int64_t N = X->Shape()[0];
-    const int64_t C = X->Shape()[1];
-    const int64_t M = W->Shape()[0];
-    ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(X, W));
+ONNX_OPERATOR_KERNEL_EX(
+    ConvInteger,
+    kOnnxDomain,
+    10,
+    kCpuExecutionProvider,
+    KernelDefBuilder()
+        .TypeConstraint("T1", {DataTypeImpl::GetTensorType<uint8_t>(),
+                               DataTypeImpl::GetTensorType<int8_t>()})
+        .TypeConstraint("T2", {DataTypeImpl::GetTensorType<uint8_t>(),
+                               DataTypeImpl::GetTensorType<int8_t>()})
+        .TypeConstraint("T3", DataTypeImpl::GetTensorType<int32_t>()),
+    ConvInteger);

-    TensorShapeVector kernel_shape;
-    ORT_RETURN_IF_ERROR(conv_attrs_.ComputeKernelShape(W->Shape(), kernel_shape));
+template <typename XT, typename WT>
+Status ConvInteger::ComputeInner(OpKernelContext* context) const {
+  const auto input_defs = Node().InputDefs();
+  size_t num_inputs = input_defs.size();
+  const auto* X = context->Input<Tensor>(0);
+  const auto* W = context->Input<Tensor>(1);
+  uint8_t input_offset = 0;
+  uint8_t filter_offset = 0;
+  if (num_inputs >= 3 && input_defs[2]->Exists()) {
+    const auto* X_Zero_Point = context->Input<Tensor>(2);
+    ORT_ENFORCE(IsScalarOr1ElementVector(X_Zero_Point), "Must be a scalar or 1D tensor of size 1.");
+    input_offset = *static_cast<const uint8_t*>(X_Zero_Point->DataRaw());
+  }
+  if (num_inputs >= 4 && input_defs[3]->Exists()) {
+    const auto* W_Zero_Point = context->Input<Tensor>(3);
+    ORT_ENFORCE(IsScalarOr1ElementVector(W_Zero_Point), "Non per-tensor quantization is not supported now.");
+    filter_offset = *static_cast<const uint8_t*>(W_Zero_Point->DataRaw());
+  }

-    ConvPadVector pads(conv_attrs_.pads);
-    if (pads.empty()) {
-      pads.resize(kernel_shape.size() * 2, 0);
-    }
-    TensorShapeVector dilations(conv_attrs_.dilations);
-    if (dilations.empty()) {
-      dilations.resize(kernel_shape.size(), 1);
-    }
-    TensorShapeVector strides(conv_attrs_.strides);
-    if (strides.empty()) {
-      strides.resize(kernel_shape.size(), 1);
-    }
+  const int64_t N = X->Shape()[0];
+  const int64_t C = X->Shape()[1];
+  const int64_t M = W->Shape()[0];
+  ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(X, W));

-    TensorShapeVector Y_dims({N, M});
-    TensorShape input_shape = X->Shape().Slice(2);
-    ORT_RETURN_IF_ERROR(conv_attrs_.InferPadsAndOutputShape(input_shape, kernel_shape, strides, dilations, pads, Y_dims));
-    Tensor* Y = context->Output(0, TensorShape(Y_dims));
-    TensorShape output_shape = Y->Shape().Slice(2);
+  TensorShapeVector kernel_shape;
+  ORT_RETURN_IF_ERROR(conv_attrs_.ComputeKernelShape(W->Shape(), kernel_shape));

-    // Bail out early if one of the dimensions is zero.
-    if (Y->Shape().Size() == 0) {
-      return Status::OK();
-    }
+  ConvPadVector pads(conv_attrs_.pads);
+  if (pads.empty()) {
+    pads.resize(kernel_shape.size() * 2, 0);
+  }
+  TensorShapeVector dilations(conv_attrs_.dilations);
+  if (dilations.empty()) {
+    dilations.resize(kernel_shape.size(), 1);
+  }
+  TensorShapeVector strides(conv_attrs_.strides);
+  if (strides.empty()) {
+    strides.resize(kernel_shape.size(), 1);
+  }

-    const int64_t input_image_size = input_shape.Size();
-    const int64_t output_image_size = output_shape.Size();
-    const int64_t kernel_size = TensorShape(kernel_shape).Size();
-    const int64_t X_offset = C / conv_attrs_.group * input_image_size;
-    const int64_t Y_offset = Y->Shape().Size() / Y->Shape()[0] / conv_attrs_.group;
-    const int64_t W_offset = W->Shape().Size() / conv_attrs_.group;
-    const int64_t kernel_dim = C / conv_attrs_.group * kernel_size;
-    const int64_t col_buffer_size = kernel_dim * output_image_size;
+  TensorShapeVector Y_dims({N, M});
+  TensorShape input_shape = X->Shape().Slice(2);
+  ORT_RETURN_IF_ERROR(conv_attrs_.InferPadsAndOutputShape(input_shape, kernel_shape, strides, dilations, pads, Y_dims));
+  Tensor* Y = context->Output(0, TensorShape(Y_dims));
+  TensorShape output_shape = Y->Shape().Slice(2);

-    const size_t kernel_rank = kernel_shape.size();
+  // Bail out early if one of the dimensions is zero.
+  if (Y->Shape().Size() == 0) {
+    return Status::OK();
+  }

-    BufferUniquePtr col_buffer;
+  const int64_t input_image_size = input_shape.Size();
+  const int64_t output_image_size = output_shape.Size();
+  const int64_t kernel_size = TensorShape(kernel_shape).Size();
+  const int64_t X_offset = C / conv_attrs_.group * input_image_size;
+  const int64_t Y_offset = Y->Shape().Size() / Y->Shape()[0] / conv_attrs_.group;
+  const int64_t W_offset = W->Shape().Size() / conv_attrs_.group;
+  const int64_t kernel_dim = C / conv_attrs_.group * kernel_size;
+  const int64_t col_buffer_size = kernel_dim * output_image_size;

-    // Pointwise convolutions can use the original input tensor in place,
-    // otherwise a temporary buffer is required for the im2col transform.
-    if (kernel_size != 1 || !conv_attrs_.HasStridesOneAndNoPadding()) {
-      AllocatorPtr alloc;
-      ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc));
+  const size_t kernel_rank = kernel_shape.size();

-      auto* col_data = alloc->Alloc(SafeInt<size_t>(sizeof(uint8_t)) * col_buffer_size);
-      col_buffer = BufferUniquePtr(col_data, BufferDeleter(std::move(alloc)));
-    }
+  BufferUniquePtr col_buffer;

-    auto* col_buffer_data = static_cast<uint8_t*>(col_buffer.get());
-
-    concurrency::ThreadPool* thread_pool = context->GetOperatorThreadPool();
-
-    const auto* Xdata = static_cast<const uint8_t*>(X->DataRaw());
-    const auto* Wdata = static_cast<const uint8_t*>(W->DataRaw());
-    auto* Ydata = Y->MutableData<int32_t>();
-
-    for (int image_id = 0; image_id < N; ++image_id) {
-      for (int group_id = 0; group_id < conv_attrs_.group; ++group_id) {
-        if (col_buffer_data != nullptr) {
-          if (kernel_rank == 2) {
-            math::Im2col<XT, StorageOrder::NCHW>()(
-                reinterpret_cast<const XT*>(Xdata),
-                C / conv_attrs_.group,
-                input_shape[0],
-                input_shape[1],
-                kernel_shape[0],
-                kernel_shape[1],
-                dilations[0],
-                dilations[1],
-                pads[0],
-                pads[1],
-                pads[2],
-                pads[3],
-                strides[0],
-                strides[1],
-                reinterpret_cast<XT*>(col_buffer_data),
-                static_cast<XT>(input_offset));
-          } else {
-            math::Im2col<XT, StorageOrder::NCHW>()(
-                reinterpret_cast<const XT*>(Xdata),
-                input_shape.GetDims().data(),
-                output_shape.GetDims().data(),
-                kernel_dim,
-                kernel_shape.data(),
-                strides.data(),
-                dilations.data(),
-                pads.data(),
-                static_cast<int>(kernel_rank),
-                reinterpret_cast<XT*>(col_buffer_data),
-                false,
-                static_cast<XT>(input_offset));
-          }
-        }
+  // Pointwise convolutions can use the original input tensor in place,
+  // otherwise a temporary buffer is required for the im2col transform.
+  if (kernel_size != 1 || !conv_attrs_.HasStridesOneAndNoPadding()) {
+    AllocatorPtr alloc;
+    ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc));
+
+    auto* col_data = alloc->Alloc(SafeInt<size_t>(sizeof(uint8_t)) * col_buffer_size);
+    col_buffer = BufferUniquePtr(col_data, BufferDeleter(std::move(alloc)));
+  }

-        MLAS_GEMM_QUANT_SHAPE_PARAMS gemm_shape;
-        gemm_shape.M = static_cast<size_t>(M / conv_attrs_.group);
-        gemm_shape.N = static_cast<size_t>(output_image_size);
-        gemm_shape.K = static_cast<size_t>(kernel_dim);
-        gemm_shape.AIsSigned = W->IsDataType<int8_t>();
-        gemm_shape.BIsSigned = X->IsDataType<int8_t>();
-
-        MLAS_GEMM_QUANT_DATA_PARAMS gemm_params;
-        gemm_params.A = Wdata + group_id * W_offset;
-        gemm_params.lda = static_cast<size_t>(kernel_dim);
-        gemm_params.ZeroPointA = filter_offset;
-        gemm_params.B = (col_buffer_data == nullptr) ? Xdata : col_buffer_data;
-        gemm_params.ldb = static_cast<size_t>(output_image_size);
-        gemm_params.ZeroPointB = &input_offset;
-        gemm_params.C = Ydata;
-        gemm_params.ldc = static_cast<size_t>(output_image_size);
-
-        MlasGemm(gemm_shape, gemm_params, thread_pool);
-
-        Xdata = reinterpret_cast<const uint8_t*>(X_offset + reinterpret_cast<const XT*>(Xdata));
-        Ydata += Y_offset;
+  auto* col_buffer_data = static_cast<uint8_t*>(col_buffer.get());
+
+  concurrency::ThreadPool* thread_pool = context->GetOperatorThreadPool();
+
+  const auto* Xdata = static_cast<const uint8_t*>(X->DataRaw());
+  const auto* Wdata = static_cast<const uint8_t*>(W->DataRaw());
+  auto* Ydata = Y->MutableData<int32_t>();
+
+  for (int image_id = 0; image_id < N; ++image_id) {
+    for (int group_id = 0; group_id < conv_attrs_.group; ++group_id) {
+      if (col_buffer_data != nullptr) {
+        if (kernel_rank == 2) {
+          math::Im2col<XT, StorageOrder::NCHW>()(
+              reinterpret_cast<const XT*>(Xdata),
+              C / conv_attrs_.group,
+              input_shape[0],
+              input_shape[1],
+              kernel_shape[0],
+              kernel_shape[1],
+              dilations[0],
+              dilations[1],
+              pads[0],
+              pads[1],
+              pads[2],
+              pads[3],
+              strides[0],
+              strides[1],
+              reinterpret_cast<XT*>(col_buffer_data),
+              static_cast<XT>(input_offset));
+        } else {
+          math::Im2col<XT, StorageOrder::NCHW>()(
+              reinterpret_cast<const XT*>(Xdata),
+              input_shape.GetDims().data(),
+              output_shape.GetDims().data(),
+              kernel_dim,
+              kernel_shape.data(),
+              strides.data(),
+              dilations.data(),
+              pads.data(),
+              static_cast<int>(kernel_rank),
+              reinterpret_cast<XT*>(col_buffer_data),
+              false,
+              static_cast<XT>(input_offset));
+        }
       }
-    }

-    return Status::OK();
+      MLAS_GEMM_QUANT_SHAPE_PARAMS gemm_shape;
+      gemm_shape.M = static_cast<size_t>(M / conv_attrs_.group);
+      gemm_shape.N = static_cast<size_t>(output_image_size);
+      gemm_shape.K = static_cast<size_t>(kernel_dim);
+      gemm_shape.AIsSigned = W->IsDataType<int8_t>();
+      gemm_shape.BIsSigned = X->IsDataType<int8_t>();
+
+      MLAS_GEMM_QUANT_DATA_PARAMS gemm_params;
+      gemm_params.A = Wdata + group_id * W_offset;
+      gemm_params.lda = static_cast<size_t>(kernel_dim);
+      gemm_params.ZeroPointA = filter_offset;
+      gemm_params.B = (col_buffer_data == nullptr) ? Xdata : col_buffer_data;
+      gemm_params.ldb = static_cast<size_t>(output_image_size);
+      gemm_params.ZeroPointB = &input_offset;
+      gemm_params.C = Ydata;
+      gemm_params.ldc = static_cast<size_t>(output_image_size);
+
+      MlasGemm(gemm_shape, gemm_params, thread_pool);
+
+      Xdata = reinterpret_cast<const uint8_t*>(X_offset + reinterpret_cast<const XT*>(Xdata));
+      Ydata += Y_offset;
+    }
   }
-};

-ONNX_OPERATOR_KERNEL_EX(
-    ConvInteger,
-    kOnnxDomain,
-    10,
-    kCpuExecutionProvider,
-    KernelDefBuilder()
-        .TypeConstraint("T1", {DataTypeImpl::GetTensorType<uint8_t>(),
-                               DataTypeImpl::GetTensorType<int8_t>()})
-        .TypeConstraint("T2", {DataTypeImpl::GetTensorType<uint8_t>(),
-                               DataTypeImpl::GetTensorType<int8_t>()})
-        .TypeConstraint("T3", DataTypeImpl::GetTensorType<int32_t>()),
-    ConvInteger);
+  return Status::OK();
+}

 Status ConvInteger::Compute(OpKernelContext* context) const {
   const auto* X = context->Input<Tensor>(0);
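Note: the hunk only moves the templated `ComputeInner<XT, WT>` out of the class body and past the kernel registration; it does not show how `Compute()` picks the template arguments. As a rough sketch (not taken from this diff; the PR's actual dispatch may differ), `Compute()` can branch on the signedness of the two inputs, which the T1/T2 type constraints in the registration above allow to be either `uint8_t` or `int8_t`:

```cpp
// Hypothetical dispatch sketch -- not part of this PR.
// Selects the ComputeInner<XT, WT> instantiation matching the runtime
// element types of X (input) and W (weights).
Status ConvInteger::Compute(OpKernelContext* context) const {
  const auto* X = context->Input<Tensor>(0);
  const auto* W = context->Input<Tensor>(1);
  if (X->IsDataType<uint8_t>()) {
    return W->IsDataType<uint8_t>() ? ComputeInner<uint8_t, uint8_t>(context)
                                    : ComputeInner<uint8_t, int8_t>(context);
  }
  return W->IsDataType<uint8_t>() ? ComputeInner<int8_t, uint8_t>(context)
                                  : ComputeInner<int8_t, int8_t>(context);
}
```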