@@ -155,6 +155,39 @@ void unpack_values(
   }
 }
 
+// Size in bytes of 1 packed weights column
+size_t inline packed_weights_size_per_n(
+    int k,
+    int group_size,
+    int weight_nbit,
+    bool has_weight_zeros,
+    bool has_bias) {
+  assert(k % group_size == 0);
+  int groups_per_col = k / group_size;
+  int col_size = 0;
+
+  // qvals
+  col_size += (k / 8) * weight_nbit;
+
+  // scales
+  col_size += sizeof(float) * groups_per_col;
+
+  // qvals_sum
+  col_size += sizeof(int32_t) * groups_per_col;
+
+  // zeros
+  if (has_weight_zeros) {
+    col_size += sizeof(int32_t) * groups_per_col;
+  }
+
+  // bias
+  if (has_bias) {
+    col_size += sizeof(float);
+  }
+
+  return col_size;
+}
+
 } // namespace internal
 
 template <int weight_nbit, int nr, int kr, int sr>
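The helper added above is pure layout arithmetic: a column of k weights contributes (k / 8) * weight_nbit bytes of packed qvals, one float scale and one int32_t qvals_sum per group, plus an optional int32_t zero per group and an optional float bias. A minimal standalone sketch of the same accounting, with illustrative sizes that are not taken from this patch:

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdio>

// Mirrors the per-column accounting in packed_weights_size_per_n.
size_t col_bytes(
    int k, int group_size, int weight_nbit,
    bool has_weight_zeros, bool has_bias) {
  assert(k % group_size == 0);
  int groups = k / group_size;
  size_t bytes = size_t(k / 8) * weight_nbit;  // packed qvals
  bytes += sizeof(float) * groups;             // scales
  bytes += sizeof(int32_t) * groups;           // qvals_sum
  if (has_weight_zeros) {
    bytes += sizeof(int32_t) * groups;         // zeros
  }
  if (has_bias) {
    bytes += sizeof(float);                    // bias
  }
  return bytes;
}

int main() {
  // k = 256, group_size = 32, 4-bit weights, zeros and bias:
  // qvals 128 + scales 32 + sums 32 + zeros 32 + bias 4 = 228 bytes.
  std::printf("%zu\n", col_bytes(256, 32, 4, true, true));
}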
@@ -281,46 +314,26 @@ size_t inline packed_weights_size(
     bool has_weight_zeros,
     bool has_bias,
     int nr) {
-  assert(k % group_size == 0);
-  int groups_per_col = k / group_size;
-  int col_size = 0;
-
-  // qvals
-  col_size += (k / 8) * weight_nbit;
-
-  // scales
-  col_size += sizeof(float) * groups_per_col;
-
-  // qvals_sum
-  col_size += sizeof(int32_t) * groups_per_col;
-
-  // zeros
-  if (has_weight_zeros) {
-    col_size += sizeof(int32_t) * groups_per_col;
-  }
-
-  // bias
-  if (has_bias) {
-    col_size += sizeof(float);
-  }
+  auto packed_weights_size_per_n = internal::packed_weights_size_per_n(
+      k, group_size, weight_nbit, has_weight_zeros, has_bias);
 
   // Replace n with next multiple of nr >= n
   n = ((n + nr - 1) / nr) * nr;
-
-  return col_size * n;
+  return packed_weights_size_per_n * n;
 }
 
-// Unpack weights
+// Unpack weights at n_idx to support shared embedding/unembedding
 template <int weight_nbit, int nr, int kr, int sr>
-void unpack_weights(
+void unpack_weights_at_n_idx(
     // Output
-    int8_t* weight_qvals,
-    float* weight_scales,
+    int8_t* weight_qvals, // k * nr values at n_idx
+    float* weight_scales, // groups_per_k * nr values at n_idx
     // weight_zeros is not extracted if has_weight_zeros is false
-    int8_t* weight_zeros,
+    int8_t* weight_zeros, // groups_per_k * nr values at n_idx
     // bias is not extracted if has_bias is false
-    float* bias,
+    float* bias, // nr values at n_idx
     // Inputs
+    int n_idx,
    int n,
    int k,
    int group_size,
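With the helper factored out, packed_weights_size reduces to two steps: round n up to the next multiple of nr (columns are packed in panels of nr, so a ragged tail still occupies a full panel) and multiply by the per-column byte count. A hedged sketch of that arithmetic, with illustrative names:

#include <cstddef>

// Next multiple of nr that is >= n, e.g. round_up_to_nr(250, 8) == 256.
int round_up_to_nr(int n, int nr) {
  return ((n + nr - 1) / nr) * nr;
}

// Total buffer size: padded column count times bytes per column.
size_t total_packed_bytes(size_t bytes_per_col, int n, int nr) {
  return bytes_per_col * size_t(round_up_to_nr(n, nr));
}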
@@ -329,6 +342,7 @@ void unpack_weights(
     void* packed_weights) {
   assert(k % group_size == 0);
   assert(group_size % kr == 0);
+  assert(n_idx % nr == 0);
 
   int groups_per_k = k / group_size;
 
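The new assert makes the caller contract explicit: the packed stream is laid out panel by panel, nr columns at a time, so unpacking may only start at a panel boundary. A hypothetical call site under assumed template arguments (weight_nbit = 4, nr = 8, kr = 16, sr = 2) and made-up buffer names, which also shows the output sizes implied by the parameter comments:

#include <cstdint>
#include <vector>

// Hypothetical helper: unpack the single nr-column panel that starts at
// panel-aligned column n_idx. The template arguments are illustrative.
void unpack_one_panel(
    int n_idx, int n, int k, int group_size, void* packed_weights) {
  constexpr int nr = 8;
  int groups_per_k = k / group_size;
  std::vector<int8_t> qvals(k * nr);
  std::vector<float> scales(groups_per_k * nr);
  std::vector<int8_t> zeros(groups_per_k * nr);
  std::vector<float> bias(nr);
  unpack_weights_at_n_idx<4, nr, 16, 2>(
      qvals.data(), scales.data(), zeros.data(), bias.data(),
      n_idx, n, k, group_size,
      /*has_weight_zeros=*/true, /*has_bias=*/true,
      packed_weights);
}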
@@ -344,72 +358,108 @@ void unpack_weights(
   constexpr int packed_buffer_bytes = weight_nbit * nr * kr / 8;
 
   // Data pointer for packed weights
-  auto packed_weights_byte_ptr = (char*)packed_weights;
-
-  // Loop over n by nr
-  for (int n_idx = 0; n_idx < n; n_idx += nr) {
-    // Look over groups along k
-    for (int group_idx = 0; group_idx < groups_per_k; group_idx++) {
-      // Loop over group by kr and pack the weights for the next nr columns
-      int k_idx = group_idx * group_size;
-      for (int idx_in_group = 0; idx_in_group < group_size;
-           idx_in_group += kr) {
-        // Unpack qvals
-        internal::unpack_buffer<weight_nbit, kr, nr>(
-            packed_values, packed_weights_byte_ptr);
-        packed_weights_byte_ptr += packed_buffer_bytes;
-        internal::unpack_values(buffer.data(), packed_values, nr, kr, sr);
-
-        // Write weight_qvals
-        for (int j = 0; j < nr; j++) {
-          if (n_idx + j < n) {
-            std::memcpy(
-                weight_qvals + (n_idx + j) * k + (k_idx + idx_in_group),
-                buffer.data() + kr * j,
-                kr);
-          }
-        }
-
-      } // loop over group (idx_in_group)
-
-      // Write group scales and zeros for next nr columns
-
-      // Write weight scales
+  auto packed_weights_byte_ptr =
+      ((char*)packed_weights +
+       n_idx *
+           internal::packed_weights_size_per_n(
+               k, group_size, weight_nbit, has_weight_zeros, has_bias));
+
+  // Loop over groups along k
+  for (int group_idx = 0; group_idx < groups_per_k; group_idx++) {
+    // Loop over group by kr and unpack the weights for the next nr columns
+    int k_idx = group_idx * group_size;
+    for (int idx_in_group = 0; idx_in_group < group_size; idx_in_group += kr) {
+      // Unpack qvals
+      internal::unpack_buffer<weight_nbit, kr, nr>(
+          packed_values, packed_weights_byte_ptr);
+      packed_weights_byte_ptr += packed_buffer_bytes;
+      internal::unpack_values(buffer.data(), packed_values, nr, kr, sr);
+
+      // Write weight_qvals
       for (int j = 0; j < nr; j++) {
-        float scale = *((float*)packed_weights_byte_ptr);
-        packed_weights_byte_ptr += sizeof(float);
         if (n_idx + j < n) {
-          weight_scales[(n_idx + j) * groups_per_k + group_idx] = scale;
+          std::memcpy(
+              weight_qvals + j * k + (k_idx + idx_in_group),
+              buffer.data() + kr * j,
+              kr);
         }
       }
 
-      // Skip over weight qval sums
-      packed_weights_byte_ptr += nr * sizeof(int);
+    } // loop over group (idx_in_group)
 
-      // Write weight zeros
-      if (has_weight_zeros) {
-        for (int j = 0; j < nr; j++) {
-          int32_t zero = *((int32_t*)packed_weights_byte_ptr);
-          packed_weights_byte_ptr += sizeof(int32_t);
-          if (n_idx + j < n) {
-            weight_zeros[(n_idx + j) * groups_per_k + group_idx] = (int8_t)zero;
-          }
-        }
+    // Write group scales and zeros for next nr columns
+
+    // Write weight scales
+    for (int j = 0; j < nr; j++) {
+      float scale = *((float*)packed_weights_byte_ptr);
+      packed_weights_byte_ptr += sizeof(float);
+      if (n_idx + j < n) {
+        weight_scales[j * groups_per_k + group_idx] = scale;
       }
+    }
 
-    } // loop over k (group_idx)
+    // Skip over weight qval sums
+    packed_weights_byte_ptr += nr * sizeof(int);
 
-    // Write bias
-    if (has_bias) {
+    // Write weight zeros
+    if (has_weight_zeros) {
       for (int j = 0; j < nr; j++) {
-        float bias_ = *((float*)packed_weights_byte_ptr);
-        packed_weights_byte_ptr += sizeof(float);
+        int32_t zero = *((int32_t*)packed_weights_byte_ptr);
+        packed_weights_byte_ptr += sizeof(int32_t);
         if (n_idx + j < n) {
-          bias[n_idx + j] = bias_;
+          weight_zeros[j * groups_per_k + group_idx] = (int8_t)zero;
         }
       }
     }
-  } // n_idx
+
+  } // loop over k (group_idx)
+
+  // Write bias
+  if (has_bias) {
+    for (int j = 0; j < nr; j++) {
+      float bias_ = *((float*)packed_weights_byte_ptr);
+      packed_weights_byte_ptr += sizeof(float);
+      if (n_idx + j < n) {
+        bias[j] = bias_;
+      }
+    }
+  }
+}
+
+template <int weight_nbit, int nr, int kr, int sr>
+void unpack_weights(
+    // Output
+    int8_t* weight_qvals,
+    float* weight_scales,
+    // weight_zeros is not extracted if has_weight_zeros is false
+    int8_t* weight_zeros,
+    // bias is not extracted if has_bias is false
+    float* bias,
+    // Inputs
+    int n,
+    int k,
+    int group_size,
+    bool has_weight_zeros,
+    bool has_bias,
+    void* packed_weights) {
+  assert(k % group_size == 0);
+  assert(group_size % kr == 0);
+  int groups_per_k = k / group_size;
+
+  for (int n_idx = 0; n_idx < n; n_idx += nr) {
+    unpack_weights_at_n_idx<weight_nbit, nr, kr, sr>(
+        weight_qvals + n_idx * k,
+        weight_scales + n_idx * groups_per_k,
+        weight_zeros + n_idx * groups_per_k,
+        bias + n_idx,
+        n_idx,
+        n,
+        k,
+        group_size,
+        has_weight_zeros,
+        has_bias,
+        packed_weights);
+  }
 }
 
 } // namespace torchao::kernels::cpu::aarch64::linear::packing
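One subtlety worth spelling out: unpack_weights_at_n_idx finds its panel with the offset n_idx * packed_weights_size_per_n(...). That is only valid because n_idx is panel-aligned, in which case n_idx columns' worth of bytes equals exactly n_idx / nr whole panels of nr columns each. A quick sanity check using the illustrative 228-byte column from the earlier sketch:

#include <cassert>
#include <cstddef>

int main() {
  size_t bytes_per_col = 228;  // illustrative, see the earlier sketch
  int nr = 8;
  int n_idx = 16;              // panel-aligned start column
  assert(n_idx % nr == 0);
  size_t per_panel = bytes_per_col * nr;              // 1824
  size_t by_columns = size_t(n_idx) * bytes_per_col;  // 3648
  size_t by_panels = size_t(n_idx / nr) * per_panel;  // 2 * 1824 = 3648
  assert(by_columns == by_panels);
}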