Skip to content

Commit ac267f8

Browse files
authored
Unpack weights at col
Differential Revision: D71170557 Pull Request resolved: #1933
1 parent 68d785c commit ac267f8

File tree

1 file changed

+132
-82
lines changed

1 file changed

+132
-82
lines changed

torchao/experimental/kernels/cpu/aarch64/linear/pack_weights.h

Lines changed: 132 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,39 @@ void unpack_values(
155155
}
156156
}
157157

158+
// Size in bytes of 1 packed weights column.
// Layout per column (in packing order): packed qvals, per-group scales,
// per-group qval sums, optional per-group zeros, optional bias.
size_t inline packed_weights_size_per_n(
    int k,
    int group_size,
    int weight_nbit,
    bool has_weight_zeros,
    bool has_bias) {
  assert(k % group_size == 0);
  int groups_per_col = k / group_size;

  // Accumulate in size_t rather than int: avoids signed overflow and a
  // signed/unsigned mix with the size_t return type for very large k.
  size_t col_size = 0;

  // qvals: weight_nbit bits per value, k values per column.
  // NOTE(review): (k / 8) * weight_nbit assumes k % 8 == 0 — confirm callers.
  col_size += (k / 8) * weight_nbit;

  // scales: one float per group
  col_size += sizeof(float) * groups_per_col;

  // qvals_sum: one int32 per group
  col_size += sizeof(int32_t) * groups_per_col;

  // zeros: one int32 per group, present only when quantization uses zeros
  if (has_weight_zeros) {
    col_size += sizeof(int32_t) * groups_per_col;
  }

  // bias: one float per column
  if (has_bias) {
    col_size += sizeof(float);
  }

  return col_size;
}
190+
158191
} // namespace internal
159192

160193
template <int weight_nbit, int nr, int kr, int sr>
@@ -281,46 +314,26 @@ size_t inline packed_weights_size(
281314
bool has_weight_zeros,
282315
bool has_bias,
283316
int nr) {
284-
assert(k % group_size == 0);
285-
int groups_per_col = k / group_size;
286-
int col_size = 0;
287-
288-
// qvals
289-
col_size += (k / 8) * weight_nbit;
290-
291-
// scales
292-
col_size += sizeof(float) * groups_per_col;
293-
294-
// qvals_sum
295-
col_size += sizeof(int32_t) * groups_per_col;
296-
297-
// zeros
298-
if (has_weight_zeros) {
299-
col_size += sizeof(int32_t) * groups_per_col;
300-
}
301-
302-
// bias
303-
if (has_bias) {
304-
col_size += sizeof(float);
305-
}
317+
auto packed_weights_size_per_n = internal::packed_weights_size_per_n(
318+
k, group_size, weight_nbit, has_weight_zeros, has_bias);
306319

307320
// Replace n with next multiple of nr >= n
308321
n = ((n + nr - 1) / nr) * nr;
309-
310-
return col_size * n;
322+
return packed_weights_size_per_n * n;
311323
}
312324

313-
// Unpack weights
325+
// Unpack weights at n_idx to support shared embedding/unembedding
314326
template <int weight_nbit, int nr, int kr, int sr>
315-
void unpack_weights(
327+
void unpack_weights_at_n_idx(
316328
// Output
317-
int8_t* weight_qvals,
318-
float* weight_scales,
329+
int8_t* weight_qvals, // k * nr values at n_idx
330+
float* weight_scales, // groups_per_k * nr values at n_idx
319331
// weight_zeros is not extracted if has_weight_zeros is false
320-
int8_t* weight_zeros,
332+
int8_t* weight_zeros, // groups_per_k * nr values at n_idx
321333
// bias is not extracted if has_bias is false
322-
float* bias,
334+
float* bias, // nr values at n_idx
323335
// Inputs
336+
int n_idx,
324337
int n,
325338
int k,
326339
int group_size,
@@ -329,6 +342,7 @@ void unpack_weights(
329342
void* packed_weights) {
330343
assert(k % group_size == 0);
331344
assert(group_size % kr == 0);
345+
assert(n_idx % nr == 0);
332346

333347
int groups_per_k = k / group_size;
334348

@@ -344,72 +358,108 @@ void unpack_weights(
344358
constexpr int packed_buffer_bytes = weight_nbit * nr * kr / 8;
345359

346360
// Data pointer for packed weights
347-
auto packed_weights_byte_ptr = (char*)packed_weights;
348-
349-
// Loop over n by nr
350-
for (int n_idx = 0; n_idx < n; n_idx += nr) {
351-
// Look over groups along k
352-
for (int group_idx = 0; group_idx < groups_per_k; group_idx++) {
353-
// Loop over group by kr and pack the weights for the next nr columns
354-
int k_idx = group_idx * group_size;
355-
for (int idx_in_group = 0; idx_in_group < group_size;
356-
idx_in_group += kr) {
357-
// Unpack qvals
358-
internal::unpack_buffer<weight_nbit, kr, nr>(
359-
packed_values, packed_weights_byte_ptr);
360-
packed_weights_byte_ptr += packed_buffer_bytes;
361-
internal::unpack_values(buffer.data(), packed_values, nr, kr, sr);
362-
363-
// Write weight_qvals
364-
for (int j = 0; j < nr; j++) {
365-
if (n_idx + j < n) {
366-
std::memcpy(
367-
weight_qvals + (n_idx + j) * k + (k_idx + idx_in_group),
368-
buffer.data() + kr * j,
369-
kr);
370-
}
371-
}
372-
373-
} // loop over group (idx_in_group)
374-
375-
// Write group scales and zeros for next nr columns
376-
377-
// Write weight scales
361+
auto packed_weights_byte_ptr =
362+
((char*)packed_weights +
363+
n_idx *
364+
internal::packed_weights_size_per_n(
365+
k, group_size, weight_nbit, has_weight_zeros, has_bias));
366+
367+
// Look over groups along k
368+
for (int group_idx = 0; group_idx < groups_per_k; group_idx++) {
369+
// Loop over group by kr and pack the weights for the next nr columns
370+
int k_idx = group_idx * group_size;
371+
for (int idx_in_group = 0; idx_in_group < group_size; idx_in_group += kr) {
372+
// Unpack qvals
373+
internal::unpack_buffer<weight_nbit, kr, nr>(
374+
packed_values, packed_weights_byte_ptr);
375+
packed_weights_byte_ptr += packed_buffer_bytes;
376+
internal::unpack_values(buffer.data(), packed_values, nr, kr, sr);
377+
378+
// Write weight_qvals
378379
for (int j = 0; j < nr; j++) {
379-
float scale = *((float*)packed_weights_byte_ptr);
380-
packed_weights_byte_ptr += sizeof(float);
381380
if (n_idx + j < n) {
382-
weight_scales[(n_idx + j) * groups_per_k + group_idx] = scale;
381+
std::memcpy(
382+
weight_qvals + j * k + (k_idx + idx_in_group),
383+
buffer.data() + kr * j,
384+
kr);
383385
}
384386
}
385387

386-
// Skip over weight qval sums
387-
packed_weights_byte_ptr += nr * sizeof(int);
388+
} // loop over group (idx_in_group)
388389

389-
// Write weight zeros
390-
if (has_weight_zeros) {
391-
for (int j = 0; j < nr; j++) {
392-
int32_t zero = *((int32_t*)packed_weights_byte_ptr);
393-
packed_weights_byte_ptr += sizeof(int32_t);
394-
if (n_idx + j < n) {
395-
weight_zeros[(n_idx + j) * groups_per_k + group_idx] = (int8_t)zero;
396-
}
397-
}
390+
// Write group scales and zeros for next nr columns
391+
392+
// Write weight scales
393+
for (int j = 0; j < nr; j++) {
394+
float scale = *((float*)packed_weights_byte_ptr);
395+
packed_weights_byte_ptr += sizeof(float);
396+
if (n_idx + j < n) {
397+
weight_scales[j * groups_per_k + group_idx] = scale;
398398
}
399+
}
399400

400-
} // loop over k (group_idx)
401+
// Skip over weight qval sums
402+
packed_weights_byte_ptr += nr * sizeof(int);
401403

402-
// Write bias
403-
if (has_bias) {
404+
// Write weight zeros
405+
if (has_weight_zeros) {
404406
for (int j = 0; j < nr; j++) {
405-
float bias_ = *((float*)packed_weights_byte_ptr);
406-
packed_weights_byte_ptr += sizeof(float);
407+
int32_t zero = *((int32_t*)packed_weights_byte_ptr);
408+
packed_weights_byte_ptr += sizeof(int32_t);
407409
if (n_idx + j < n) {
408-
bias[n_idx + j] = bias_;
410+
weight_zeros[j * groups_per_k + group_idx] = (int8_t)zero;
409411
}
410412
}
411413
}
412-
} // n_idx
414+
415+
} // loop over k (group_idx)
416+
417+
// Write bias
418+
if (has_bias) {
419+
for (int j = 0; j < nr; j++) {
420+
float bias_ = *((float*)packed_weights_byte_ptr);
421+
packed_weights_byte_ptr += sizeof(float);
422+
if (n_idx + j < n) {
423+
bias[j] = bias_;
424+
}
425+
}
426+
}
427+
}
428+
429+
template <int weight_nbit, int nr, int kr, int sr>
430+
void unpack_weights(
431+
// Output
432+
int8_t* weight_qvals,
433+
float* weight_scales,
434+
// weight_zeros is not extracted if has_weight_zeros is false
435+
int8_t* weight_zeros,
436+
// bias is not extracted if has_bias is false
437+
float* bias,
438+
// Inputs
439+
int n,
440+
int k,
441+
int group_size,
442+
bool has_weight_zeros,
443+
bool has_bias,
444+
void* packed_weights) {
445+
assert(k % group_size == 0);
446+
assert(group_size % kr == 0);
447+
int groups_per_k = k / group_size;
448+
449+
for (int n_idx = 0; n_idx < n; n_idx += nr) {
450+
unpack_weights_at_n_idx<weight_nbit, nr, kr, sr>(
451+
weight_qvals + n_idx * k,
452+
weight_scales + n_idx * groups_per_k,
453+
weight_zeros + n_idx * groups_per_k,
454+
bias + n_idx,
455+
n_idx,
456+
n,
457+
k,
458+
group_size,
459+
has_weight_zeros,
460+
has_bias,
461+
packed_weights);
462+
}
413463
}
414464

415465
} // namespace torchao::kernels::cpu::aarch64::linear::packing

0 commit comments

Comments
 (0)