
Commit 97e26fd

Preliminary SSSE3 support
1 parent cef4b86 commit 97e26fd

6 files changed: 508 additions & 20 deletions


smallthinker/powerinfer/libaz/az/core/fp16.h

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@
 
 #include "az/core/intrinsics.hpp"
 #include "stdint.h"
+#include <math.h>
 
 #if defined(__cplusplus)
 extern "C" {

smallthinker/powerinfer/libaz/az/cpu/quant_types.cpp

Lines changed: 21 additions & 1 deletion
@@ -187,7 +187,27 @@ void quantize_row_q8_0(block_q8_0 *out, const float *in, size_t n) {
 #endif // __AVX2__
 }
 #else
-    abort();
+    for (int i = 0; i < nb; i++) {
+        float amax = 0.0f; // absolute max
+
+        for (size_t j = 0; j < block_q8_0::block_size; j++) {
+            const float v = in[i*block_q8_0::block_size + j];
+            if (amax < fabsf(v)) {
+                amax = fabsf(v);
+            }
+        }
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        out[i].d = AZ_FP32_TO_FP16(d);
+
+        for (size_t j = 0; j < block_q8_0::block_size; ++j) {
+            const float x0 = in[i*block_q8_0::block_size + j]*id;
+
+            out[i].qs[j] = roundf(x0);
+        }
+    }
 #endif
 }
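The scalar fallback above replaces the bare abort() with the standard Q8_0 scheme: each block stores one scale d = amax / 127 plus signed byte quants q[j] = round(x[j] / d), and the <math.h> include added to fp16.h presumably covers the fabsf/roundf calls on this path. As a reading aid, here is a minimal self-contained sketch of the same math, assuming a block size of 32 and a plain float scale (the real block_q8_0 stores its scale as fp16 via AZ_FP32_TO_FP16), together with the matching dequantization; it is illustrative only, not code from this commit.

#include <math.h>
#include <stdint.h>

enum { Q8_BLOCK = 32 };  // assumed block size; block_q8_0::block_size in the real code

// Quantize one block: pick d so that the largest |x| maps to +/-127, then round.
static void quantize_block_q8_0_ref(const float in[Q8_BLOCK], float *d_out, int8_t qs[Q8_BLOCK]) {
    float amax = 0.0f;
    for (int j = 0; j < Q8_BLOCK; j++) {
        const float v = fabsf(in[j]);
        if (v > amax) amax = v;
    }
    const float d  = amax / 127.0f;        // (1 << 7) - 1 in the diff
    const float id = d ? 1.0f / d : 0.0f;  // guard against an all-zero block
    for (int j = 0; j < Q8_BLOCK; j++) {
        qs[j] = (int8_t)roundf(in[j] * id);
    }
    *d_out = d;
}

// The matching dequantization: x ~= d * q.
static void dequantize_block_q8_0_ref(float d, const int8_t qs[Q8_BLOCK], float out[Q8_BLOCK]) {
    for (int j = 0; j < Q8_BLOCK; j++) {
        out[j] = d * (float)qs[j];
    }
}

Round-tripping a block through these two helpers bounds the per-element reconstruction error by d/2.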

smallthinker/powerinfer/libaz/az/cpu/vec_dot.cpp

Lines changed: 197 additions & 12 deletions
@@ -9,12 +9,13 @@
 static int sve_cnt = PR_SVE_VL_LEN_MASK & prctl(PR_SVE_GET_VL);
 #endif
 
-#if defined(__AVX2__)
+// Ref:llama.cpp(https://github.com/ggml-org/llama.cpp) ggml/src/ggml-cpu/arch/x86/quants.c
+#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
 
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
 namespace {
-
 // multiply int8_t, add results pairwise twice
-inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
+static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
     // Get absolute values of x vectors
     const __m128i ax = _mm_sign_epi8(x, x);
     // Sign the values of the y vectors
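The MM256_SET_M128I(a, b) macro introduced above is the portable spelling of _mm256_set_m128i(a, b): the result's low 128-bit half is b and its high half is a, built only from intrinsics available wherever AVX is. A small standalone illustration, not part of the commit (assumes an AVX build):

#include <immintrin.h>
#include <stdio.h>

#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)

int main(void) {
    const __m128i lo = _mm_set1_epi32(1);
    const __m128i hi = _mm_set1_epi32(2);
    const __m256i v  = MM256_SET_M128I(hi, lo);  // high half = hi, low half = lo

    int out[8];
    _mm256_storeu_si256((__m256i *)out, v);
    for (int i = 0; i < 8; i++) printf("%d ", out[i]);
    printf("\n");  // prints: 1 1 1 1 2 2 2 2
    return 0;
}

The helpers below lean on this macro whenever two 128-bit halves have to be glued into one 256-bit value.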
@@ -25,44 +26,82 @@ inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
     return _mm_madd_epi16(ones, dot);
 }
 
+#if __AVX__ || __AVX2__ || __AVX512F__
 // horizontally add 8 floats
-inline float hsum_float_8(const __m256 x) {
+static inline float hsum_float_8(const __m256 x) {
     __m128 res = _mm256_extractf128_ps(x, 1);
     res = _mm_add_ps(res, _mm256_castps256_ps128(x));
     res = _mm_add_ps(res, _mm_movehl_ps(res, res));
     res = _mm_add_ss(res, _mm_movehdup_ps(res));
     return _mm_cvtss_f32(res);
 }
 
+// horizontally add 8 int32_t
+static inline int hsum_i32_8(const __m256i a) {
+    const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1));
+    const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
+    const __m128i sum64 = _mm_add_epi32(hi64, sum128);
+    const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
+    return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
+}
+
+// horizontally add 4 int32_t
+static inline int hsum_i32_4(const __m128i a) {
+    const __m128i hi64 = _mm_unpackhi_epi64(a, a);
+    const __m128i sum64 = _mm_add_epi32(hi64, a);
+    const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
+    return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
+}
+
+#if defined(__AVX2__) || defined(__AVX512F__)
+// spread 32 bits to 32 bytes { 0x00, 0xFF }
+static inline __m256i bytes_from_bits_32(const uint8_t * x) {
+    uint32_t x32;
+    memcpy(&x32, x, sizeof(uint32_t));
+    const __m256i shuf_mask = _mm256_set_epi64x(
+            0x0303030303030303, 0x0202020202020202,
+            0x0101010101010101, 0x0000000000000000);
+    __m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(x32), shuf_mask);
+    const __m256i bit_mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe);
+    bytes = _mm256_or_si256(bytes, bit_mask);
+    return _mm256_cmpeq_epi8(bytes, _mm256_set1_epi64x(-1));
+}
+
 // Unpack 32 4-bit fields into 32 bytes
 // The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
-inline __m256i bytes_from_nibbles_32(const uint8_t *rsi) {
+static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
+{
     const __m128i tmp = _mm_loadu_si128((const __m128i *)rsi);
-    const __m256i bytes = _mm256_set_m128i(_mm_srli_epi16(tmp, 4), tmp);
-    const __m256i lowMask = _mm256_set1_epi8(0xF);
+    const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp);
+    const __m256i lowMask = _mm256_set1_epi8( 0xF );
     return _mm256_and_si256(lowMask, bytes);
 }
 
 // add int16_t pairwise and return as float vector
-inline __m256 sum_i16_pairs_float(const __m256i x) {
+static inline __m256 sum_i16_pairs_float(const __m256i x) {
     const __m256i ones = _mm256_set1_epi16(1);
     const __m256i summed_pairs = _mm256_madd_epi16(ones, x);
     return _mm256_cvtepi32_ps(summed_pairs);
 }
 
-inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
-#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
+static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
+#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
     const __m256i zero = _mm256_setzero_si256();
     const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
     return _mm256_cvtepi32_ps(summed_pairs);
+#elif defined(__AVXVNNI__)
+    const __m256i zero = _mm256_setzero_si256();
+    const __m256i summed_pairs = _mm256_dpbusd_avx_epi32(zero, ax, sy);
+    return _mm256_cvtepi32_ps(summed_pairs);
 #else
     // Perform multiplication and create 16-bit values
     const __m256i dot = _mm256_maddubs_epi16(ax, sy);
     return sum_i16_pairs_float(dot);
 #endif
 }
 
-inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
+// multiply int8_t, add results pairwise twice and return as float vector
+static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
 #if __AVXVNNIINT8__
     const __m256i zero = _mm256_setzero_si256();
     const __m256i summed_pairs = _mm256_dpbssd_epi32(zero, x, y);
@@ -76,9 +115,155 @@ inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
 #endif
 }
 
+static inline __m128i packNibbles( __m256i bytes )
+{
+    // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
+#if __AVX512F__
+    const __m256i bytes_srli_4 = _mm256_srli_epi16(bytes, 4); // 0000_0000_abcd_0000
+    bytes = _mm256_or_si256(bytes, bytes_srli_4);             // 0000_abcd_abcd_efgh
+    return _mm256_cvtepi16_epi8(bytes);                       // abcd_efgh
+#else
+    const __m256i lowByte = _mm256_set1_epi16( 0xFF );
+    __m256i high = _mm256_andnot_si256( lowByte, bytes );
+    __m256i low = _mm256_and_si256( lowByte, bytes );
+    high = _mm256_srli_epi16( high, 4 );
+    bytes = _mm256_or_si256( low, high );
+
+    // Compress uint16_t lanes into bytes
+    __m128i r0 = _mm256_castsi256_si128( bytes );
+    __m128i r1 = _mm256_extracti128_si256( bytes, 1 );
+    return _mm_packus_epi16( r0, r1 );
+#endif
+}
+#elif defined(__AVX__)
+static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
+{
+    // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
+    const __m128i lowByte = _mm_set1_epi16( 0xFF );
+    __m128i high = _mm_andnot_si128( lowByte, bytes1 );
+    __m128i low = _mm_and_si128( lowByte, bytes1 );
+    high = _mm_srli_epi16( high, 4 );
+    bytes1 = _mm_or_si128( low, high );
+    high = _mm_andnot_si128( lowByte, bytes2 );
+    low = _mm_and_si128( lowByte, bytes2 );
+    high = _mm_srli_epi16( high, 4 );
+    bytes2 = _mm_or_si128( low, high );
+
+    return _mm_packus_epi16( bytes1, bytes2);
+}
+
+static inline __m128i mul_add_epi8_sse(const __m128i x, const __m128i y) {
+    const __m128i ax = _mm_sign_epi8(x, x);
+    const __m128i sy = _mm_sign_epi8(y, x);
+    return _mm_maddubs_epi16(ax, sy);
+}
+
+// spread 32 bits to 32 bytes { 0x00, 0xFF }
+static inline __m256i bytes_from_bits_32(const uint8_t * x) {
+    uint32_t x32;
+    memcpy(&x32, x, sizeof(uint32_t));
+    const __m128i shuf_maskl = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000);
+    const __m128i shuf_maskh = _mm_set_epi64x(0x0303030303030303, 0x0202020202020202);
+    __m128i bytesl = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskl);
+    __m128i bytesh = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskh);
+    const __m128i bit_mask = _mm_set1_epi64x(0x7fbfdfeff7fbfdfe);
+    bytesl = _mm_or_si128(bytesl, bit_mask);
+    bytesh = _mm_or_si128(bytesh, bit_mask);
+    bytesl = _mm_cmpeq_epi8(bytesl, _mm_set1_epi64x(-1));
+    bytesh = _mm_cmpeq_epi8(bytesh, _mm_set1_epi64x(-1));
+    return MM256_SET_M128I(bytesh, bytesl);
+}
+
+// Unpack 32 4-bit fields into 32 bytes
+// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
+static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
+{
+    // Load 16 bytes from memory
+    __m128i tmpl = _mm_loadu_si128((const __m128i *)rsi);
+    __m128i tmph = _mm_srli_epi16(tmpl, 4);
+    const __m128i lowMask = _mm_set1_epi8(0xF);
+    tmpl = _mm_and_si128(lowMask, tmpl);
+    tmph = _mm_and_si128(lowMask, tmph);
+    return MM256_SET_M128I(tmph, tmpl);
+}
+
+// add int16_t pairwise and return as float vector
+static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) {
+    const __m128i ones = _mm_set1_epi16(1);
+    const __m128i summed_pairsl = _mm_madd_epi16(ones, xl);
+    const __m128i summed_pairsh = _mm_madd_epi16(ones, xh);
+    const __m256i summed_pairs = MM256_SET_M128I(summed_pairsh, summed_pairsl);
+    return _mm256_cvtepi32_ps(summed_pairs);
+}
+
+static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
+    const __m128i axl = _mm256_castsi256_si128(ax);
+    const __m128i axh = _mm256_extractf128_si256(ax, 1);
+    const __m128i syl = _mm256_castsi256_si128(sy);
+    const __m128i syh = _mm256_extractf128_si256(sy, 1);
+    // Perform multiplication and create 16-bit values
+    const __m128i dotl = _mm_maddubs_epi16(axl, syl);
+    const __m128i doth = _mm_maddubs_epi16(axh, syh);
+    return sum_i16_pairs_float(doth, dotl);
 }
 
-#endif // __AVX2__
+// multiply int8_t, add results pairwise twice and return as float vector
+static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
+    const __m128i xl = _mm256_castsi256_si128(x);
+    const __m128i xh = _mm256_extractf128_si256(x, 1);
+    const __m128i yl = _mm256_castsi256_si128(y);
+    const __m128i yh = _mm256_extractf128_si256(y, 1);
+    // Get absolute values of x vectors
+    const __m128i axl = _mm_sign_epi8(xl, xl);
+    const __m128i axh = _mm_sign_epi8(xh, xh);
+    // Sign the values of the y vectors
+    const __m128i syl = _mm_sign_epi8(yl, xl);
+    const __m128i syh = _mm_sign_epi8(yh, xh);
+    // Perform multiplication and create 16-bit values
+    const __m128i dotl = _mm_maddubs_epi16(axl, syl);
+    const __m128i doth = _mm_maddubs_epi16(axh, syh);
+    return sum_i16_pairs_float(doth, dotl);
+}
+
+// larger version of mul_sum_i8_pairs_float where x and y are each represented by four 128-bit vectors
+static inline __m256 mul_sum_i8_quad_float(const __m128i x_1_0, const __m128i x_1_1, const __m128i x_2_0, const __m128i x_2_1,
+                                           const __m128i y_1_0, const __m128i y_1_1, const __m128i y_2_0, const __m128i y_2_1) {
+    const __m128i mone = _mm_set1_epi16(1);
+
+    const __m128i p16_1_0 = mul_add_epi8_sse(x_1_0, y_1_0);
+    const __m128i p16_1_1 = mul_add_epi8_sse(x_1_1, y_1_1);
+    const __m128i p16_2_0 = mul_add_epi8_sse(x_2_0, y_2_0);
+    const __m128i p16_2_1 = mul_add_epi8_sse(x_2_1, y_2_1);
+    const __m128i p_1_0 = _mm_madd_epi16(p16_1_0, mone);
+    const __m128i p_1_1 = _mm_madd_epi16(p16_1_1, mone);
+    const __m128i p_2_0 = _mm_madd_epi16(p16_2_0, mone);
+    const __m128i p_2_1 = _mm_madd_epi16(p16_2_1, mone);
+    const __m128i p_1 = _mm_add_epi32(p_1_0, p_1_1);
+    const __m128i p_2 = _mm_add_epi32(p_2_0, p_2_1);
+    return _mm256_cvtepi32_ps(MM256_SET_M128I(p_2, p_1));
+}
+
+// quad fp16 delta calculation
+static inline __m256 quad_fp16_delta_float(const float x0, const float y0, const float x1, const float y1) {
+    // GGML_CPU_FP16_TO_FP32 is faster than Intel F16C
+    return _mm256_set_m128(_mm_set1_ps(GGML_CPU_FP16_TO_FP32(x1) * GGML_CPU_FP16_TO_FP32(y1)),
+                           _mm_set1_ps(GGML_CPU_FP16_TO_FP32(x0) * GGML_CPU_FP16_TO_FP32(y0)));
+}
+#endif
+#elif defined(__SSSE3__)
+// horizontally add 4x4 floats
+static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) {
+    __m128 res_0 =_mm_hadd_ps(a, b);
+    __m128 res_1 =_mm_hadd_ps(c, d);
+    __m128 res =_mm_hadd_ps(res_0, res_1);
+    res =_mm_hadd_ps(res, res);
+    res =_mm_hadd_ps(res, res);
+
+    return _mm_cvtss_f32(res);
+}
+#endif // __AVX__ || __AVX2__ || __AVX512F__
+}
+#endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
 
 namespace az::cpu {
 
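As for the SSSE3 path this commit is preparing, the pieces above would typically combine into a q8_0 x q8_0 dot product roughly as sketched below. This is a hypothetical outline under simplifying assumptions (blocks of 32 int8 quants with a plain float scale, and made-up names dot_i8_16 and vec_dot_q8_q8_ssse3), not the kernel from this commit: the per-block integer sums reuse the _mm_sign_epi8 / _mm_maddubs_epi16 / _mm_madd_epi16 sequence of mul_sum_i8_pairs, and the final reduction uses the same _mm_hadd_ps idea as hsum_float_4x4.

#include <immintrin.h>
#include <stddef.h>
#include <stdint.h>

// Simplified stand-in for block_q8_0: 32 int8 quants and a float scale
// (the real struct stores the scale as fp16).
struct blk_q8 { float d; int8_t qs[32]; };

// Same trick as mul_sum_i8_pairs: multiply |x| by sign-adjusted y, then
// reduce 16 int8 products to 4 int32 partial sums.
static inline __m128i dot_i8_16(const __m128i x, const __m128i y) {
    const __m128i ax  = _mm_sign_epi8(x, x);        // |x|
    const __m128i sy  = _mm_sign_epi8(y, x);        // y carrying x's signs
    const __m128i dot = _mm_maddubs_epi16(ax, sy);  // pairwise u8*s8 products, summed to 8 x i16
    return _mm_madd_epi16(_mm_set1_epi16(1), dot);  // summed again to 4 x i32
}

float vec_dot_q8_q8_ssse3(size_t nb, const struct blk_q8 *x, const struct blk_q8 *y) {
    __m128 acc = _mm_setzero_ps();
    for (size_t i = 0; i < nb; i++) {
        const __m128i x0 = _mm_loadu_si128((const __m128i *)(x[i].qs +  0));
        const __m128i x1 = _mm_loadu_si128((const __m128i *)(x[i].qs + 16));
        const __m128i y0 = _mm_loadu_si128((const __m128i *)(y[i].qs +  0));
        const __m128i y1 = _mm_loadu_si128((const __m128i *)(y[i].qs + 16));
        const __m128i p  = _mm_add_epi32(dot_i8_16(x0, y0), dot_i8_16(x1, y1));
        // scale the block's integer dot product by d_x * d_y and accumulate
        acc = _mm_add_ps(acc, _mm_mul_ps(_mm_cvtepi32_ps(p), _mm_set1_ps(x[i].d * y[i].d)));
    }
    // horizontal reduction of the 4 partial sums, as in hsum_float_4x4
    __m128 res = _mm_hadd_ps(acc, acc);
    res = _mm_hadd_ps(res, res);
    return _mm_cvtss_f32(res);
}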