Commit 35c9d88

gptq_4bit
qgemm_4bit CPU support (slow!); in_features x bit-width support fix; minor Windows fix; weight index rebase-conflict fix
1 parent 6ada97c commit 35c9d88

11 files changed: +1224 −0 lines changed


onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc

Lines changed: 4 additions & 0 deletions
@@ -30,6 +30,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Tokenizer);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Range);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QuantNbitsGemm);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, DequantizeAndUnpackWeight);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, WordConvEmbedding);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, GatherND);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, TransposeMatMul);  // backward compatibility
@@ -302,6 +304,8 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Tokenizer)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Range)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QuantNbitsGemm)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, DequantizeAndUnpackWeight)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, WordConvEmbedding)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, GatherND)>,
 #if !defined(DISABLE_SPARSE_TENSORS)
Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include <cstdint>
+#include <cstdio>
+#include "core/common/common.h"
+#include "core/framework/op_kernel.h"
+#include "core/framework/tensor.h"
+#include "core/framework/tensorprotoutils.h"
+
+namespace onnxruntime {
+namespace contrib {
+
+class DequantizeAndUnpackWeight final : public OpKernel {
+ public:
+  explicit DequantizeAndUnpackWeight(const OpKernelInfo& info) : OpKernel{info} {
+    ORT_ENFORCE(info.GetAttr<int64_t>("bits", &bits_).IsOK());
+    ORT_ENFORCE(info.GetAttr<int64_t>("groupsize", &groupsize_).IsOK());
+    in_features_ = info.GetAttrOrDefault<int64_t>("in_features", -1);
+
+    ORT_ENFORCE(bits_ > 1 && bits_ < 9, "bits must be in range [2, 8]");
+    if (bits_ != 2 && bits_ != 4 && bits_ != 8 && in_features_ == -1) {
+      ORT_THROW("in_features must be specified for bits other than 2, 4, 8");
+    }
+    if (in_features_ == -1) {
+      // Infer in_features from the packed weight: each int32 holds 32 / bits_ values along dim 0.
+      const auto& node{Node()};
+      const auto& input_defs = node.InputDefs();
+      const NodeArg& X = *input_defs[0];
+      auto X_shape = utils::GetTensorShapeFromTensorShapeProto(*X.Shape());
+      in_features_ = X_shape[0] * (32 / bits_);
+    }
+  }
+
+  Status Compute(OpKernelContext* context) const override;
+
+ private:
+  template <typename T>
+  struct ComputeImpl;
+
+  int64_t bits_;
+  int64_t groupsize_;
+  int64_t in_features_;
+};
+
+ONNX_OPERATOR_KERNEL_EX(
+    DequantizeAndUnpackWeight,
+    kMSDomain,
+    1,
+    kCpuExecutionProvider,
+    (*KernelDefBuilder::Create())
+        .TypeConstraint("T", BuildKernelDefConstraints<uint32_t, int32_t>()),
+    DequantizeAndUnpackWeight);
+
+void DequantNbitWeight(OpKernelContext* ctx, const Tensor* input_weight, Tensor* output, const Tensor* input_zeros,
+                       const Tensor* input_scale, const int64_t bits_, const int64_t compress_ratio,
+                       const int64_t groupsize_);
+
+Status DequantizeAndUnpackWeight::Compute(OpKernelContext* ctx) const {
+  const auto* input_weight = ctx->Input<Tensor>(0);
+  const auto* input_scale = ctx->Input<Tensor>(1);
+  const auto* input_zeros = ctx->Input<Tensor>(2);
+  // const auto* input_gidx = ctx->Input<Tensor>(5);
+  const auto& qweight_shape = input_weight->Shape();
+  const int64_t compress_ratio = sizeof(int32_t) * 8 / bits_;
+  TensorShape output_shape = qweight_shape;
+  output_shape[0] = output_shape[0] * compress_ratio;  // unpacked rows = packed rows * (32 / bits)
+  auto* output = ctx->Output(0, output_shape);
+  DequantNbitWeight(ctx, input_weight, output, input_zeros, input_scale, bits_, compress_ratio, groupsize_);
+
+  return Status::OK();
+}
+
+}  // namespace contrib
+}  // namespace onnxruntime
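
For intuition, a minimal standalone sketch of how in_features and the unpacked output shape follow from the packed weight shape in the kernel above. The concrete sizes are hypothetical, assuming 4-bit values packed into int32 along dim 0:

#include <cstdint>
#include <cstdio>

int main() {
  const int64_t bits = 4;
  const int64_t packed_rows = 512;           // qweight dim 0 (hypothetical)
  const int64_t cols = 4096;                 // qweight dim 1 (hypothetical)
  const int64_t compress_ratio = 32 / bits;  // 8 weights per int32
  const int64_t in_features = packed_rows * compress_ratio;
  // Matches the kernel: output dim 0 = packed dim 0 * (32 / bits).
  std::printf("in_features = %lld, unpacked shape = (%lld, %lld)\n",
              (long long)in_features, (long long)in_features, (long long)cols);
  return 0;
}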
Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
+#include <cstdint>
+#include "core/framework/float16.h"
+#include "core/platform/threadpool.h"
+#include "core/common/common.h"
+#include "core/framework/op_kernel.h"
+#include "core/framework/tensor.h"
+
+namespace onnxruntime {
+namespace contrib {
+
+void DequantNbitWeight(OpKernelContext* ctx, const Tensor* input_weight, Tensor* output, const Tensor* input_zeros,
+                       const Tensor* input_scale, const int64_t bits_, const int64_t compress_ratio,
+                       const int64_t groupsize_) {
+  if (ctx == nullptr) return;  // no thread pool without a kernel context
+  const auto& qweight_shape = input_weight->Shape();
+  const uint32_t* u32_in = reinterpret_cast<const uint32_t*>(input_weight->Data<int32_t>());
+  float* f32_out = output->MutableData<float>();
+  const uint32_t* u32_zeros = reinterpret_cast<const uint32_t*>(input_zeros->Data<int32_t>());
+  const MLFloat16* f16_scale = input_scale->Data<MLFloat16>();
+  const uint32_t mask = (1u << bits_) - 1;  // low-bit mask for one quantized value
+
+  // Each packed row mi expands to compress_ratio unpacked rows; parallelize over packed rows.
+  int64_t task_count = qweight_shape[0];
+  concurrency::ThreadPool::TryBatchParallelFor(
+      ctx->GetOperatorThreadPool(), static_cast<int32_t>(task_count),
+      [&](ptrdiff_t task_idx) {
+        int64_t mi = task_idx;
+        // Group index on the unpacked row; GPTQ-style layout assumed, with
+        // groupsize_ a multiple of compress_ratio so a packed row stays in one group.
+        int64_t group = mi * compress_ratio / groupsize_;
+        for (int64_t ki = 0; ki < qweight_shape[1]; ki++) {
+          uint32_t u32_weight = u32_in[mi * qweight_shape[1] + ki];
+          // Zeros are packed compress_ratio per int32 along the column dimension.
+          uint32_t u32_zero = u32_zeros[group * (qweight_shape[1] / compress_ratio) + ki / compress_ratio];
+          uint8_t u8_zero = (u32_zero >> ((ki % compress_ratio) * bits_)) & mask;
+          float f32_scale_val = (f16_scale[group * qweight_shape[1] + ki]).ToFloat();
+          float scale_zero = f32_scale_val * u8_zero;
+          for (int64_t w_idx = 0; w_idx < compress_ratio; w_idx++) {
+            // q * scale - scale * zero == (q - zero) * scale
+            f32_out[(mi * compress_ratio + w_idx) * qweight_shape[1] + ki] = (u32_weight & mask) * f32_scale_val - scale_zero;
+            u32_weight = u32_weight >> bits_;
+          }
+        }
+      },
+      0);
+}
+
+}  // namespace contrib
+}  // namespace onnxruntime
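
To make the inner loop concrete, here is a minimal self-contained sketch of unpacking and dequantizing one packed word, mirroring the loop above. All values are made up; GPTQ-style 4-bit packing is assumed:

#include <cstdint>
#include <cstdio>

int main() {
  const int bits = 4;
  const uint32_t mask = (1u << bits) - 1;
  uint32_t packed = 0x76543210u;  // eight 4-bit quantized weights in one int32
  const float scale = 0.05f;      // per-group scale (single group here)
  const uint8_t zero = 8;         // per-group zero point
  for (int i = 0; i < 32 / bits; ++i) {
    uint8_t q = packed & mask;    // low nibble first, as in the kernel
    float w = (static_cast<float>(q) - static_cast<float>(zero)) * scale;
    std::printf("w[%d]: q=%u -> %f\n", i, q, w);
    packed >>= bits;
  }
  return 0;
}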
Lines changed: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include <cstdint>
+#include <cstdio>
+#include <vector>
+#include "core/common/common.h"
+#include "core/framework/op_kernel.h"
+#include "core/framework/tensor.h"
+
+namespace onnxruntime {
+namespace contrib {
+
+class QuantNbitsGemm final : public OpKernel {
+ public:
+  explicit QuantNbitsGemm(const OpKernelInfo& info) : OpKernel{info} {
+    // ORT_ENFORCE(info.GetAttr("outfeatures", &outfeatures_).IsOK());
+    // ORT_ENFORCE(info.GetAttr("infeatures", &in_features_).IsOK());
+    bits_ = info.GetAttrOrDefault<int64_t>("bits", 3);
+    groupsize_ = info.GetAttrOrDefault<int64_t>("groupsize", 128);
+  }
+
+  Status Compute(OpKernelContext* context) const override;
+
+ private:
+  template <typename T>
+  struct ComputeImpl;
+
+  int64_t outfeatures_;
+  int64_t in_features_;
+  int64_t bits_;
+  int64_t groupsize_;
+};
+
+ONNX_OPERATOR_KERNEL_EX(
+    QuantNbitsGemm,
+    kMSDomain,
+    1,
+    kCpuExecutionProvider,
+    (*KernelDefBuilder::Create())
+        .TypeConstraint("T", BuildKernelDefConstraints<float, MLFloat16>()),
+    QuantNbitsGemm);
+
+// NOTE: diagnostic stub — the output tensor is allocated and shapes are derived,
+// but the quantized GEMM itself is not computed yet; Compute only prints diagnostics.
+Status QuantNbitsGemm::Compute(OpKernelContext* ctx) const {
+  const auto* input_x = ctx->Input<Tensor>(0);
+  const auto* input_weight = ctx->Input<Tensor>(1);
+  // const auto* input_scale = ctx->Input<Tensor>(2);
+  const auto* input_zeros = ctx->Input<Tensor>(3);
+  // const auto* input_bias = ctx->Input<Tensor>(4);
+  // const auto* input_gidx = ctx->Input<Tensor>(5);
+  const auto& input_shape = input_x->Shape();
+  const auto& weight_shape = input_weight->Shape();
+  TensorShapeVector output_shape = input_shape.AsShapeVector();
+  output_shape[output_shape.size() - 1] = weight_shape[1];
+  auto* output = ctx->Output(0, output_shape);
+  auto batch = input_shape[0] * (input_shape.NumDimensions() > 2 ? input_shape[1] : 1);
+  // int64_t in_features = input_shape[input_shape.NumDimensions() - 1];
+  input_x->Data<MLFloat16>();  // type check only; the value is unused in this stub
+  // auto* outp = output->Data<MLFloat16>();
+  // input_scale->Data<MLFloat16>();
+  printf("%lld,%lld\n", static_cast<long long>(batch),
+         static_cast<long long>(output->Shape()[1]));
+
+  size_t sz = weight_shape[0] * weight_shape[1] * 2;
+  std::vector<int32_t> buf(sz);
+  printf("%d...%d,", input_weight->Data<int32_t>()[0], input_zeros->Data<int32_t>()[0]);
+
+  return Status::OK();
+}
+
+}  // namespace contrib
+}  // namespace onnxruntime
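
As committed, Compute is only a diagnostic stub. A straightforward (slow) reference path would dequantize the weight with the routine above and then run a plain fp32 GEMM. A minimal sketch of that fallback, with made-up names and no claim to match the eventual implementation:

#include <cstdint>

// Hypothetical reference path: y[M,N] = x[M,K] * w_deq[K,N], where w_deq is the
// float weight produced by a DequantizeAndUnpackWeight-style unpack step.
void NaiveQGemmReference(const float* x, const float* w_deq, float* y,
                         int64_t M, int64_t K, int64_t N) {
  for (int64_t m = 0; m < M; ++m) {
    for (int64_t n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (int64_t k = 0; k < K; ++k) {
        acc += x[m * K + k] * w_deq[k * N + n];
      }
      y[m * N + n] = acc;
    }
  }
}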

onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc

Lines changed: 4 additions & 0 deletions
@@ -144,6 +144,8 @@ class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, MatMulNBits);
 class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, MatMulBnb4);
 class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, MatMulBnb4);
 class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, MatMulBnb4);
+class CUDA_MS_OP_CLASS_NAME(1, QuantNbitsGemm);  // untyped macro, matching the registration below
+class CUDA_MS_OP_CLASS_NAME(1, DequantizeAndUnpackWeight);
 class CUDA_MS_OP_CLASS_NAME(1, Trilu);
 class CUDA_MS_OP_CLASS_NAME(1, UnfoldTensor);
 class CUDA_MS_OP_CLASS_NAME(1, DynamicTimeWarping);
@@ -348,6 +350,8 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, MatMulBnb4)>,
       BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, MatMulBnb4)>,
       BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, MatMulBnb4)>,
+      BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, QuantNbitsGemm)>,
+      BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, DequantizeAndUnpackWeight)>,
       BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, BiasSoftmax)>,
       BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, BiasDropout)>,
       BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, BitmaskDropout)>,
Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/providers/cuda/cuda_kernel.h"
+#include "core/framework/tensorprotoutils.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+class DequantizeAndUnpackWeight final : public ::onnxruntime::cuda::CudaKernel {
+ public:
+  explicit DequantizeAndUnpackWeight(const OpKernelInfo& info) : CudaKernel{info} {
+    ORT_ENFORCE(info.GetAttr<int64_t>("bits", &bits_).IsOK());
+    ORT_ENFORCE(info.GetAttr<int64_t>("groupsize", &group_size_).IsOK());
+    in_features_ = info.GetAttrOrDefault<int64_t>("in_features", -1);
+
+    ORT_ENFORCE(bits_ > 1 && bits_ < 9, "bits must be in range [2, 8]");
+    if (bits_ != 2 && bits_ != 4 && bits_ != 8 && in_features_ == -1) {
+      ORT_THROW("in_features must be specified for bits other than 2, 4, 8");
+    }
+    if (in_features_ == -1) {
+      // Infer in_features from the packed weight: each int32 holds 32 / bits_ values along dim 0.
+      const auto& node{Node()};
+      const auto& input_defs = node.InputDefs();
+      const NodeArg& X = *input_defs[0];
+      in_features_ = X.Shape()->dim(0).dim_value() * (32 / bits_);
+    }
+  }
+
+  Status ComputeInternal(OpKernelContext* context) const override;
+
+ private:
+  using Base = CudaKernel;
+  int64_t bits_;
+  int64_t group_size_;
+  int64_t in_features_;
+};
+
+ONNX_OPERATOR_KERNEL_EX(
+    DequantizeAndUnpackWeight,
+    kMSDomain,
+    1,
+    kCudaExecutionProvider,
+    (*KernelDefBuilder::Create())
+        .TypeConstraint("T", BuildKernelDefConstraints<int32_t>())
+        .TypeConstraint("T1", BuildKernelDefConstraints<MLFloat16>()),
+    DequantizeAndUnpackWeight);
+
+// Defined in the CUDA source for this op (not shown in this excerpt).
+void DequantWeightNbit(
+    cudaStream_t stream,
+    const int32_t* qweight_i32,
+    const void* scales_data,
+    const int32_t* zeros_data,
+    void* weight_out,
+    uint32_t MATRIX_K,
+    uint32_t MATRIX_N,
+    uint32_t bits,
+    uint32_t groupsize);
+
+Status DequantizeAndUnpackWeight::ComputeInternal(OpKernelContext* ctx) const {
+  const auto* qweight = ctx->Input<Tensor>(0);
+  const auto* input_scale = ctx->Input<Tensor>(1);
+  const auto* input_zeros = ctx->Input<Tensor>(2);
+
+  auto output_shape = qweight->Shape();
+  output_shape[0] = in_features_;
+
+  auto* output = ctx->Output(0, output_shape);
+  DequantWeightNbit(Stream(ctx), qweight->Data<int32_t>(),
+                    input_scale->Data<MLFloat16>(),
+                    input_zeros->Data<int32_t>(),
+                    output->MutableData<MLFloat16>(),
+                    static_cast<uint32_t>(in_features_),
+                    static_cast<uint32_t>(output_shape[1]),
+                    static_cast<uint32_t>(bits_),
+                    static_cast<uint32_t>(group_size_));
+  return Status::OK();
+}
+
+}  // namespace cuda
+}  // namespace contrib
+}  // namespace onnxruntime
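
For orientation, a small standalone sketch of the tensor shapes this op consumes. The layout is an assumption (GPTQ-style: qweight packed along K, one fp16 scale per group per column, zeros packed 32/bits per int32 along N), and the concrete sizes are hypothetical:

#include <cstdint>
#include <cstdio>

int main() {
  const int64_t bits = 4, groupsize = 128;
  const int64_t in_features = 4096, out_features = 4096;  // hypothetical layer size
  const int64_t pack = 32 / bits;                          // values per int32
  // qweight (T = int32), packed along K
  std::printf("qweight: (%lld, %lld)\n", (long long)(in_features / pack), (long long)out_features);
  // scales (T1 = fp16), one per (group, column)
  std::printf("scales : (%lld, %lld)\n", (long long)(in_features / groupsize), (long long)out_features);
  // zeros (int32), packed along N
  std::printf("zeros  : (%lld, %lld)\n", (long long)(in_features / groupsize), (long long)(out_features / pack));
  // output (T1 = fp16), the fully unpacked weight
  std::printf("output : (%lld, %lld)\n", (long long)in_features, (long long)out_features);
  return 0;
}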
