// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <cstdint>
#include "core/common/common.h"
#include "core/framework/op_kernel.h"
#include "core/framework/tensor.h"
#include "core/framework/tensorprotoutils.h"

namespace onnxruntime {
namespace contrib {

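// Dequantizes and unpacks weights that were quantized to `bits` bits and
// packed row-wise into 32-bit words, using per-group scales and zero points
// (every `groupsize` rows of the unpacked weight share one scale/zero pair).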
class DequantizeAndUnpackWeight final : public OpKernel {
 public:
  explicit DequantizeAndUnpackWeight(const OpKernelInfo& info) : OpKernel{info} {
    ORT_ENFORCE(info.GetAttr<int64_t>("bits", &bits_).IsOK());
    ORT_ENFORCE(info.GetAttr<int64_t>("groupsize", &groupsize_).IsOK());
    in_features_ = info.GetAttrOrDefault<int64_t>("in_features", -1);

    ORT_ENFORCE(bits_ >= 2 && bits_ <= 8, "bits must be in range [2, 8]");
    if (bits_ != 2 && bits_ != 4 && bits_ != 8 && in_features_ == -1) {
      ORT_THROW("in_features must be specified when bits is not 2, 4, or 8");
    }
    if (in_features_ == -1) {
      // For bit widths that divide 32 evenly, each packed 32-bit word holds
      // exactly 32 / bits_ values, so in_features can be inferred from the
      // packed weight's first dimension.
      const auto& node{Node()};
      const auto& input_defs = node.InputDefs();
      const NodeArg& X = *input_defs[0];
      auto X_shape = utils::GetTensorShapeFromTensorShapeProto(*X.Shape());
      in_features_ = X_shape[0] * (32 / bits_);
    }
  }

  Status Compute(OpKernelContext* context) const override;

 private:
  template <typename T>
  struct ComputeImpl;

  int64_t bits_;
  int64_t groupsize_;
  int64_t in_features_;
};

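// Register the kernel in the com.microsoft (contrib) domain for the CPU
// execution provider; T constrains the packed weight to 32-bit integer
// element types.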
ONNX_OPERATOR_KERNEL_EX(
    DequantizeAndUnpackWeight,
    kMSDomain,
    1,
    kCpuExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", BuildKernelDefConstraints<uint32_t, int32_t>()),
    DequantizeAndUnpackWeight);

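// DequantNbitWeight is declared here and defined in a separate translation
// unit. A minimal sketch of the semantics it is assumed to implement
// (row-packed weights, per-group scales and zero points; the zero points may
// themselves be bit-packed, but are shown unpacked for clarity):
//
//   for each output row r and column c:
//     word  = qweight[r / compress_ratio][c]
//     shift = (r % compress_ratio) * bits
//     q     = (word >> shift) & ((1 << bits) - 1)
//     g     = r / groupsize
//     output[r][c] = scale[g][c] * (q - zeros[g][c])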
void DequantNbitWeight(OpKernelContext* ctx, const Tensor* input_weight, Tensor* output, const Tensor* input_zeros,
                       const Tensor* input_scale, int64_t bits, int64_t compress_ratio,
                       int64_t groupsize);

Status DequantizeAndUnpackWeight::Compute(OpKernelContext* ctx) const {
  const auto* input_weight = ctx->Input<Tensor>(0);
  const auto* input_scale = ctx->Input<Tensor>(1);
  const auto* input_zeros = ctx->Input<Tensor>(2);
  // const auto* input_gidx = ctx->Input<Tensor>(5);
  const auto& qweight_shape = input_weight->Shape();
  // Each 32-bit word of the packed weight holds 32 / bits_ quantized values,
  // so the unpacked output has compress_ratio times as many rows.
  const int64_t compress_ratio = sizeof(int32_t) * 8 / bits_;
  TensorShape output_shape = qweight_shape;
  output_shape[0] = output_shape[0] * compress_ratio;
  auto* output = ctx->Output(0, output_shape);
  DequantNbitWeight(ctx, input_weight, output, input_zeros, input_scale, bits_, compress_ratio, groupsize_);

  return Status::OK();
}

}  // namespace contrib
}  // namespace onnxruntime