
Commit 8776dd3

quantized matmul

Differential Revision: D71370592
Pull Request resolved: #1994

1 parent c9b1490

File tree: 9 files changed, +1324 −31 lines changed
Lines changed: 384 additions & 0 deletions
@@ -0,0 +1,384 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#if defined(__aarch64__) || defined(__ARM_NEON)

#include <algorithm>
#include <cassert>
#include <cstring>
#include <memory>

#include <arm_neon.h>
#include <torchao/experimental/kernels/cpu/aarch64/macro.h>
#include <torchao/experimental/kernels/cpu/aarch64/matmul/matmul_utils.h>

namespace torchao::kernels::cpu::aarch64::quantized_matmul {
namespace channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal::internal {

namespace {
/*
This function loads one int8x16_t value from a, and 16 int8x16_t values from
b, one per row. Setup:
- subl to subtract a_zero_point from a, giving a_low and a_high (int16x8_t)
- the 16 int32 accumulators are held as 4 int32x4_t values
For i in [0, 8):
- load b[i]
- subl to subtract b_zero_point from b[i], giving b_low and b_high
- smlal_lane to multiply-accumulate a_low[i] with b_low_low
- smlal_lane to multiply-accumulate a_low[i] with b_low_high
- smlal_lane to multiply-accumulate a_low[i] with b_high_low
- smlal_lane to multiply-accumulate a_low[i] with b_high_high
For i in [8, 16):
- load b[i]
- subl to subtract b_zero_point from b[i], giving b_low and b_high
- smlal_lane to multiply-accumulate a_high[i - 8] with b_low_low
- smlal_lane to multiply-accumulate a_high[i - 8] with b_low_high
- smlal_lane to multiply-accumulate a_high[i - 8] with b_high_low
- smlal_lane to multiply-accumulate a_high[i - 8] with b_high_high
Possibly better to transpose a 16x16 block of b and use dotprod. Left for
future work.
*/

template <int lane>
TORCHAO_ALWAYS_INLINE void block_mul_1x16x1(
    const int16x4_t& a_vec,
    const int8x16_t& b_vec,
    const int8x16_t& b_zero_point_vec,
    int32x4_t (&partial_sums)[4]) {
  int16x8_t b_vec_low =
      vsubl_s8(vget_low_s8(b_vec), vget_low_s8(b_zero_point_vec));
  int16x8_t b_vec_high =
      vsubl_s8(vget_high_s8(b_vec), vget_high_s8(b_zero_point_vec));
  partial_sums[0] =
      vmlal_lane_s16(partial_sums[0], vget_low_s16(b_vec_low), a_vec, lane);
  partial_sums[1] =
      vmlal_lane_s16(partial_sums[1], vget_high_s16(b_vec_low), a_vec, lane);
  partial_sums[2] =
      vmlal_lane_s16(partial_sums[2], vget_low_s16(b_vec_high), a_vec, lane);
  partial_sums[3] =
      vmlal_lane_s16(partial_sums[3], vget_high_s16(b_vec_high), a_vec, lane);
}
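
// Added illustrative sketch (not part of the original kernel): a scalar
// reference for what block_mul_1x16x1 computes. One lane of the
// zero-point-corrected activation scales all 16 zero-point-corrected b
// values, accumulating into 16 int32 sums (held above as int32x4_t[4]).
// The name block_mul_1x16x1_scalar_ref is hypothetical.
inline void block_mul_1x16x1_scalar_ref(
    int16_t a_lane, // one lane of (a - a_zero_point)
    const int8_t (&b_vec)[16],
    const int8_t (&b_zero_point)[16],
    int32_t (&partial_sums)[16]) {
  for (int j = 0; j < 16; ++j) {
    partial_sums[j] += static_cast<int32_t>(a_lane) *
        (static_cast<int16_t>(b_vec[j]) -
         static_cast<int16_t>(b_zero_point[j]));
  }
}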

void block_mul_1x16x16(
    const int8_t* a,
    const int8_t* b,
    const size_t ldb,
    const int8_t a_zero_point,
    const int8_t* b_zero_point,
    int32x4_t (&partial_sums)[4]) {
  int8x16_t a_vec = vld1q_s8(a);
  int8x8_t a_zero_point_vec = vdup_n_s8(a_zero_point);
  int8x16_t b_zero_point_vec = vld1q_s8(b_zero_point);
  int16x8_t a_vec_low = vsubl_s8(vget_low_s8(a_vec), a_zero_point_vec);
  int16x8_t a_vec_high = vsubl_s8(vget_high_s8(a_vec), a_zero_point_vec);

  int8x16_t b_vec = vld1q_s8(b + 0 * ldb);
  block_mul_1x16x1<0>(
      vget_low_s16(a_vec_low), b_vec, b_zero_point_vec, partial_sums);
  b_vec = vld1q_s8(b + 1 * ldb);
  block_mul_1x16x1<1>(
      vget_low_s16(a_vec_low), b_vec, b_zero_point_vec, partial_sums);
  b_vec = vld1q_s8(b + 2 * ldb);
  block_mul_1x16x1<2>(
      vget_low_s16(a_vec_low), b_vec, b_zero_point_vec, partial_sums);
  b_vec = vld1q_s8(b + 3 * ldb);
  block_mul_1x16x1<3>(
      vget_low_s16(a_vec_low), b_vec, b_zero_point_vec, partial_sums);
  b_vec = vld1q_s8(b + 4 * ldb);
  block_mul_1x16x1<0>(
      vget_high_s16(a_vec_low), b_vec, b_zero_point_vec, partial_sums);
  b_vec = vld1q_s8(b + 5 * ldb);
  block_mul_1x16x1<1>(
      vget_high_s16(a_vec_low), b_vec, b_zero_point_vec, partial_sums);
  b_vec = vld1q_s8(b + 6 * ldb);
  block_mul_1x16x1<2>(
      vget_high_s16(a_vec_low), b_vec, b_zero_point_vec, partial_sums);
  b_vec = vld1q_s8(b + 7 * ldb);
  block_mul_1x16x1<3>(
      vget_high_s16(a_vec_low), b_vec, b_zero_point_vec, partial_sums);

  // Second set of 8 k values, using the high half of a
  b_vec = vld1q_s8(b + 8 * ldb);
  block_mul_1x16x1<0>(
      vget_low_s16(a_vec_high), b_vec, b_zero_point_vec, partial_sums);
  b_vec = vld1q_s8(b + 9 * ldb);
  block_mul_1x16x1<1>(
      vget_low_s16(a_vec_high), b_vec, b_zero_point_vec, partial_sums);
  b_vec = vld1q_s8(b + 10 * ldb);
  block_mul_1x16x1<2>(
      vget_low_s16(a_vec_high), b_vec, b_zero_point_vec, partial_sums);
  b_vec = vld1q_s8(b + 11 * ldb);
  block_mul_1x16x1<3>(
      vget_low_s16(a_vec_high), b_vec, b_zero_point_vec, partial_sums);
  b_vec = vld1q_s8(b + 12 * ldb);
  block_mul_1x16x1<0>(
      vget_high_s16(a_vec_high), b_vec, b_zero_point_vec, partial_sums);
  b_vec = vld1q_s8(b + 13 * ldb);
  block_mul_1x16x1<1>(
      vget_high_s16(a_vec_high), b_vec, b_zero_point_vec, partial_sums);
  b_vec = vld1q_s8(b + 14 * ldb);
  block_mul_1x16x1<2>(
      vget_high_s16(a_vec_high), b_vec, b_zero_point_vec, partial_sums);
  b_vec = vld1q_s8(b + 15 * ldb);
  block_mul_1x16x1<3>(
      vget_high_s16(a_vec_high), b_vec, b_zero_point_vec, partial_sums);
}
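
// Added note (sketch, not in the original): with channelwise quantization,
// the int32 accumulator for row i, column j dequantizes as
//   out[i][j] = lhs_scales[i] * rhs_scales[j] * sum[i][j]
// dequantize_1x16_int32_t below applies this to 16 columns at once.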

TORCHAO_ALWAYS_INLINE void dequantize_1x16_int32_t(
    const int32x4_t (&sums)[4],
    const float* lhs_scales,
    const float* rhs_scales,
    float32x4_t (&outputs)[4]) {
  float32x4_t scales_0123 = vmulq_n_f32(vld1q_f32(rhs_scales), lhs_scales[0]);
  float32x4_t scales_4567 =
      vmulq_n_f32(vld1q_f32(rhs_scales + 4), lhs_scales[0]);
  float32x4_t scales_89ab =
      vmulq_n_f32(vld1q_f32(rhs_scales + 8), lhs_scales[0]);
  float32x4_t scales_cdef =
      vmulq_n_f32(vld1q_f32(rhs_scales + 12), lhs_scales[0]);

  outputs[0] = vmulq_f32(vcvtq_f32_s32(sums[0]), scales_0123);
  outputs[1] = vmulq_f32(vcvtq_f32_s32(sums[1]), scales_4567);
  outputs[2] = vmulq_f32(vcvtq_f32_s32(sums[2]), scales_89ab);
  outputs[3] = vmulq_f32(vcvtq_f32_s32(sums[3]), scales_cdef);
}

template <
    bool a_has_zeros,
    bool b_has_zeros,
    bool a_transposed,
    bool b_transposed>
struct KernelImpl {
  static void run(
      int m,
      int n,
      int k,
      const void* lhs,
      int lhs_stride_m,
      const void* rhs,
      int rhs_stride_n,
      float32_t* output,
      int out_stride_m,
      const int8_t* lhs_zero_points,
      const int8_t* rhs_zero_points,
      const float* lhs_scales,
      const float* rhs_scales,
      const int lhs_qparams_stride,
      const int rhs_qparams_stride);
};
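
// Added note: only the <true, true, false, false> specialization below
// provides a definition in this file; other template-argument combinations
// are declared but not implemented here.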

template <>
struct KernelImpl<true, true, false, false> {
  /**
   * @brief Implements quantized matrix multiplication for 8-bit channelwise
   * quantized matrices
   *
   * This specialized implementation handles the case where:
   * - Both LHS and RHS have zero points (true, true)
   * - Neither LHS nor RHS is transposed (false, false)
   *
   * The function performs a quantized matrix multiplication C = A * B where:
   * - A is an m×k matrix (LHS)
   * - B is a k×n matrix (RHS)
   * - C is an m×n matrix (output)
   *
   * The implementation uses NEON intrinsics for vectorized computation and
   * processes data in blocks of 16×16 for optimal performance on ARM
   * architecture.
   *
   * @param m Number of rows in LHS and output
   * @param n Number of columns in RHS and output
   * @param k Number of columns in LHS and rows in RHS
   * @param lhs Pointer to LHS matrix data (int8_t)
   * @param lhs_stride_m Stride between rows of LHS
   * @param rhs Pointer to RHS matrix data (int8_t)
   * @param rhs_stride_n Stride between rows of RHS
   * @param output Pointer to output matrix (float32_t)
   * @param out_stride_m Stride between rows of output
   * @param lhs_zero_points Zero points for LHS quantization (per-channel)
   * @param rhs_zero_points Zero points for RHS quantization (per-channel)
   * @param lhs_scales Scales for LHS quantization (per-channel)
   * @param rhs_scales Scales for RHS quantization (per-channel)
   * @param lhs_qparams_stride Stride for LHS quantization parameters
   * @param rhs_qparams_stride Stride for RHS quantization parameters
   */
  static void run(
      int m,
      int n,
      int k,
      const void* lhs,
      int lhs_stride_m,
      const void* rhs,
      int rhs_stride_n,
      float32_t* output,
      int out_stride_m,
      const int8_t* lhs_zero_points,
      const int8_t* rhs_zero_points,
      const float* lhs_scales,
      const float* rhs_scales,
      const int lhs_qparams_stride,
      const int rhs_qparams_stride) {
    // If the LHS quantization parameters are strided rather than contiguous,
    // gather them into contiguous transposed buffers first.
    std::unique_ptr<int8_t[]> lhs_zero_points_transposed =
        std::make_unique<int8_t[]>(m);
    std::unique_ptr<float[]> lhs_scales_transposed =
        std::make_unique<float[]>(m);
    if (lhs_qparams_stride > 1) {
      utils::transpose_scales_and_zero_points(
          lhs_zero_points,
          lhs_scales,
          lhs_zero_points_transposed.get(),
          lhs_scales_transposed.get(),
          m,
          lhs_qparams_stride);
      lhs_zero_points = lhs_zero_points_transposed.get();
      lhs_scales = lhs_scales_transposed.get();
    }
    // Same for the RHS quantization parameters.
    std::unique_ptr<int8_t[]> rhs_zero_points_transposed =
        std::make_unique<int8_t[]>(n);
    std::unique_ptr<float[]> rhs_scales_transposed =
        std::make_unique<float[]>(n);
    if (rhs_qparams_stride > 1) {
      utils::transpose_scales_and_zero_points(
          rhs_zero_points,
          rhs_scales,
          rhs_zero_points_transposed.get(),
          rhs_scales_transposed.get(),
          n,
          rhs_qparams_stride);
      rhs_zero_points = rhs_zero_points_transposed.get();
      rhs_scales = rhs_scales_transposed.get();
    }

    for (int m_idx = 0; m_idx < m; m_idx++) {
      // Loop over 16 cols at a time.
      // Access to partial tiles must be protected.
      constexpr int nr = 16;
      constexpr int kr = 16;
      assert(n >= nr);
      for (int n_idx = 0; n_idx < n; n_idx += nr) {
        // If remaining < nr, then (nr - remaining) of this tile's columns
        // were already computed by the previous tile. To avoid out-of-bounds
        // access, rewind n_idx a bit:
        // |-------------------|-------------------|
        // 0-------------------8-------------------16
        // 0-------------------8-----10
        // If n = 10 and nr = 8, then at n_idx = 8 we have remaining =
        // 10 - 8 = 2, so we rewind n_idx to 8 - (8 - 2) = 2.
        int remaining = std::min(n - n_idx, nr);
        n_idx = n_idx - (nr - remaining);
        // Set lhs_ptr to the start of the activation qvals for row m_idx
        const int8_t* lhs_ptr = (const int8_t*)lhs + m_idx * lhs_stride_m;
        const int8_t* rhs_ptr = (const int8_t*)rhs + n_idx;
        // All nr/4 accumulators start at zero (elements beyond the first
        // initializer are value-initialized).
        int32x4_t int32_sums[nr / 4] = {vdupq_n_s32(0)};

        // Loop over k in blocks of kr
        int k_idx = 0;
        for (; (k_idx + kr) <= k; k_idx += kr) {
          block_mul_1x16x16(
              lhs_ptr,
              rhs_ptr,
              rhs_stride_n,
              lhs_zero_points[m_idx],
              rhs_zero_points + n_idx,
              int32_sums);
          lhs_ptr += kr;
          rhs_ptr += kr * rhs_stride_n;
        }

        int8x16_t b_zero_point_vec = vld1q_s8(rhs_zero_points + n_idx);
        for (int ki = 0; ki < (k - k_idx); ++ki) {
          // For each of the remaining k values:
          // load 1 int8_t from lhs, load 16 int8_t from rhs, and
          // multiply-add into the 16 accumulators arranged as int32x4_t[4].
          int16_t a_val = static_cast<int16_t>(lhs_ptr[ki]) -
              static_cast<int16_t>(lhs_zero_points[m_idx]);
          int8x16_t b_vec = vld1q_s8(rhs_ptr + ki * rhs_stride_n);
          int16x8_t b_vec_low =
              vsubl_s8(vget_low_s8(b_vec), vget_low_s8(b_zero_point_vec));
          int16x8_t b_vec_high =
              vsubl_s8(vget_high_s8(b_vec), vget_high_s8(b_zero_point_vec));
          int32_sums[0] =
              vmlal_n_s16(int32_sums[0], vget_low_s16(b_vec_low), a_val);
          int32_sums[1] =
              vmlal_n_s16(int32_sums[1], vget_high_s16(b_vec_low), a_val);
          int32_sums[2] =
              vmlal_n_s16(int32_sums[2], vget_low_s16(b_vec_high), a_val);
          int32_sums[3] =
              vmlal_n_s16(int32_sums[3], vget_high_s16(b_vec_high), a_val);
        }

        float32x4_t res[4];
        dequantize_1x16_int32_t(
            int32_sums, lhs_scales + m_idx, rhs_scales + n_idx, res);

        // Store result. Because we rewind n_idx for partial tiles, we may
        // write some locations twice.
        float* store_loc = output + m_idx * out_stride_m + n_idx;
        vst1q_f32(store_loc, res[0]);
        vst1q_f32(store_loc + 4, res[1]);
        vst1q_f32(store_loc + 8, res[2]);
        vst1q_f32(store_loc + 12, res[3]);
      } // n_idx
    } // m_idx
  }
};

} // namespace

} // namespace channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal::internal

namespace channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal {
template <
    bool a_has_zeros,
    bool b_has_zeros,
    bool a_transposed,
    bool b_transposed>
void kernel(
    int m,
    int n,
    int k,
    const void* lhs,
    int lhs_stride_m,
    const void* rhs,
    int rhs_stride_n,
    float32_t* output,
    int out_stride_m,
    const int8_t* lhs_zero_points,
    const int8_t* rhs_zero_points,
    const float* lhs_scales,
    const float* rhs_scales,
    const int lhs_qparams_stride,
    const int rhs_qparams_stride) {
  torchao::kernels::cpu::aarch64::quantized_matmul::
      channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal::internal::
          KernelImpl<a_has_zeros, b_has_zeros, a_transposed, b_transposed>::
              run(m,
                  n,
                  k,
                  lhs,
                  lhs_stride_m,
                  rhs,
                  rhs_stride_n,
                  output,
                  out_stride_m,
                  lhs_zero_points,
                  rhs_zero_points,
                  lhs_scales,
                  rhs_scales,
                  lhs_qparams_stride,
                  rhs_qparams_stride);
}
} // namespace channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal
} // namespace torchao::kernels::cpu::aarch64::quantized_matmul

#endif // defined(__aarch64__) || defined(__ARM_NEON)
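
For orientation, a minimal call of the exposed kernel template might look like
the sketch below. The buffer sizes, scale values, and the example() wrapper are
illustrative assumptions, not part of the commit; only the kernel signature,
the n >= 16 assertion, and the qparams-stride convention come from the code
above.

#include <cstdint>
#include <vector>

void example() {
  using namespace torchao::kernels::cpu::aarch64::quantized_matmul;
  constexpr int m = 4, n = 32, k = 64; // n >= 16; k need not be a multiple of 16
  std::vector<int8_t> lhs(m * k), rhs(k * n); // row-major A (m x k), B (k x n)
  std::vector<float> out(m * n);
  std::vector<int8_t> lhs_zp(m, 0), rhs_zp(n, 0); // per-channel zero points
  std::vector<float> lhs_scales(m, 0.02f), rhs_scales(n, 0.05f);

  channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal::
      kernel<true, true, false, false>(
          m, n, k,
          lhs.data(), /*lhs_stride_m=*/k,
          rhs.data(), /*rhs_stride_n=*/n,
          out.data(), /*out_stride_m=*/n,
          lhs_zp.data(), rhs_zp.data(),
          lhs_scales.data(), rhs_scales.data(),
          /*lhs_qparams_stride=*/1, /*rhs_qparams_stride=*/1);
}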
