@@ -60,27 +60,47 @@ namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p {
 
 using Ukernel = struct kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel;
 
-template <int mr, int kr, int sr>
-size_t
-activation_data_size(int m, int k, int group_size, bool has_weight_zeros) {
+size_t packed_activations_size(
+    int m,
+    int k,
+    int group_size,
+    bool has_weight_zeros,
+    int mr,
+    int kr,
+    int sr) {
   (void)group_size; // unused
   (void)has_weight_zeros; // unused
   auto lhs_packing = get_lhs_packing();
   return lhs_packing.get_lhs_packed_size(m, k, mr, kr, sr);
 }
 
-template <int mr, int kr, int sr>
-void prepare_activation_data(
-    void* activation_data,
+size_t packed_activations_offset(
+    int m_idx,
+    int k,
+    int group_size,
+    bool has_weight_zeros,
+    int mr,
+    int kr,
+    int sr) {
+  (void)group_size; // unused
+  (void)has_weight_zeros; // unused
+  auto lhs_pack = get_lhs_packing();
+  return lhs_pack.get_lhs_packed_offset(m_idx, k, mr, kr, sr);
+}
+
+void pack_activations(
+    void* packed_activations,
     int m,
     int k,
     int group_size,
     const float* activations,
-    bool has_weight_zeros) {
+    bool has_weight_zeros,
+    int mr,
+    int kr,
+    int sr) {
   (void)group_size; // unused
   (void)has_weight_zeros; // unused
   auto lhs_pack = get_lhs_packing();
-
   lhs_pack.run_lhs_pack(
       m,
       k,
@@ -90,33 +110,62 @@ void prepare_activation_data(
       /*m_index_start=*/0,
       activations,
       /*lhs_stride=*/k * sizeof(float),
-      activation_data);
+      packed_activations);
 }
 
-template <int nr, int kr, int sr>
-size_t weight_data_size(
+size_t packed_weights_size(
     int n,
     int k,
     int group_size,
+    int weight_nbit,
     bool has_weight_zeros,
-    bool has_bias) {
+    bool has_bias,
+    int nr,
+    int kr,
+    int sr) {
+  (void)weight_nbit; // unused
   (void)has_weight_zeros; // unused
   (void)has_bias; // unused
   auto rhs_pack = get_rhs_packing();
   return rhs_pack.get_rhs_packed_size(
-      n, k, nr, kr, sr, group_size, kai_datatype::kai_dt_bf16);
+      internal::adjust_n(n),
+      k,
+      nr,
+      kr,
+      sr,
+      group_size,
+      kai_datatype::kai_dt_bf16);
+}
+
+size_t packed_weights_offset(
+    int n_idx,
+    int k,
+    int group_size,
+    int weight_nbit,
+    bool has_weight_zeros,
+    bool has_bias,
+    int nr,
+    int kr,
+    int sr) {
+  (void)has_weight_zeros; // unused
+  (void)has_bias; // unused
+  auto rhs_pack = get_rhs_packing();
+  return rhs_pack.get_rhs_packed_offset(
+      n_idx, k, nr, kr, sr, group_size, kai_datatype::kai_dt_bf16);
 }
 
-template <int nr, int kr, int sr>
-void prepare_weight_data(
-    void* weight_data,
+void pack_weights(
+    void* packed_weights,
     int n,
     int k,
     int group_size,
     const int8_t* weight_qvals,
     const float* weight_scales,
     const int8_t* weight_zeros,
-    const float* bias) {
+    const float* bias,
+    int nr,
+    int kr,
+    int sr) {
   if (group_size % 32 != 0) {
     throw std::runtime_error(
         "Group size must be a multiple of 32, but got group_size=" +
@@ -187,7 +236,7 @@ void prepare_weight_data(
       reinterpret_cast<const uint16_t*>(weight_scales_bf16_padded.data()),
       /*scale_stride=*/sizeof(uint16_t) *
           (internal::roundup(k, group_size) / group_size),
-      /*rhs_packed=*/weight_data,
+      /*rhs_packed=*/packed_weights,
       /*extra_bytes=*/0,
       /*qparams=*/&qparams);
 }
@@ -220,8 +269,8 @@ size_t get_preferred_alignement() {
     int n, \
     int k, \
     int group_size, \
-    const void* weight_data, \
-    const void* activation_data, \
+    const void* packed_weights, \
+    const void* packed_activations, \
     float clamp_min, \
     float clamp_max, \
     bool has_weight_zeros, \
@@ -235,11 +284,11 @@ size_t get_preferred_alignement() {
   } \
   get_ukernel().run_matmul( \
       m, \
-      internal::adjust_n(n), \
+      n, \
       k, \
       group_size, \
-      activation_data, \
-      weight_data, \
+      packed_activations, \
+      packed_weights, \
       output, \
       /*dst_stride_row=*/output_m_stride * sizeof(float), \
       /*dst_stride_col=*/sizeof(float), \
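To see how the renamed entry points compose: the change moves the mr/nr/kr/sr tile parameters from template arguments to runtime arguments (so one compiled packing routine can serve any ukernel variant) and adds packed_*_offset helpers for locating tiles inside the packed buffers. Below is a minimal caller sketch. Only the function names and argument orders come from the diff above; the surrounding function, buffer handling, weight_nbit=4, and the tile values are illustrative assumptions.

// Hypothetical caller (not part of the diff), written as if inside the
// kai_matmul_clamp_f32_qai8dxp_qsi4c32p namespace from the hunk headers.
// Tile shape and group_size are placeholders chosen to show argument order.
#include <cstdint>
#include <vector>

void example_pack(
    int m,
    int n,
    int k,
    const float* activations,
    const int8_t* weight_qvals,
    const float* weight_scales,
    std::vector<char>& packed_activations_buf,
    std::vector<char>& packed_weights_buf) {
  const int group_size = 32; // pack_weights requires a multiple of 32
  const int mr = 1, nr = 8, kr = 16, sr = 2; // placeholder tile parameters
  const bool has_weight_zeros = false;
  const bool has_bias = false;

  // Weights are packed once, ahead of time.
  packed_weights_buf.resize(packed_weights_size(
      n, k, group_size, /*weight_nbit=*/4, has_weight_zeros, has_bias,
      nr, kr, sr));
  pack_weights(
      packed_weights_buf.data(), n, k, group_size, weight_qvals,
      weight_scales, /*weight_zeros=*/nullptr, /*bias=*/nullptr, nr, kr, sr);

  // Activations are packed (quantized and reordered) per inference call.
  packed_activations_buf.resize(packed_activations_size(
      m, k, group_size, has_weight_zeros, mr, kr, sr));
  pack_activations(
      packed_activations_buf.data(), m, k, group_size, activations,
      has_weight_zeros, mr, kr, sr);

  // For tiled or multithreaded dispatch, the new *_offset helpers give the
  // byte offset of a block inside each packed buffer, e.g.
  //   packed_activations_offset(m_idx, k, group_size, has_weight_zeros,
  //                             mr, kr, sr);
}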