99static int sve_cnt = PR_SVE_VL_LEN_MASK & prctl (PR_SVE_GET_VL);
1010#endif
1111
12- #if defined(__AVX2__)
12+ // Ref:llama.cpp(https://github.com/ggml-org/llama.cpp) ggml/src/ggml-cpu/arch/x86/quants.c
13+ #define MM256_SET_M128I (a, b ) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1 )
1314
15+ #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
1416namespace {
15-
1617// multiply int8_t, add results pairwise twice
17- inline __m128i mul_sum_i8_pairs (const __m128i x, const __m128i y) {
18+ static inline __m128i mul_sum_i8_pairs (const __m128i x, const __m128i y) {
1819 // Get absolute values of x vectors
1920 const __m128i ax = _mm_sign_epi8 (x, x);
2021 // Sign the values of the y vectors
@@ -25,44 +26,82 @@ inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
2526 return _mm_madd_epi16 (ones, dot);
2627}
2728
29+ #if __AVX__ || __AVX2__ || __AVX512F__
2830// horizontally add 8 floats
29- inline float hsum_float_8 (const __m256 x) {
31+ static inline float hsum_float_8 (const __m256 x) {
3032 __m128 res = _mm256_extractf128_ps (x, 1 );
3133 res = _mm_add_ps (res, _mm256_castps256_ps128 (x));
3234 res = _mm_add_ps (res, _mm_movehl_ps (res, res));
3335 res = _mm_add_ss (res, _mm_movehdup_ps (res));
3436 return _mm_cvtss_f32 (res);
3537}
3638
39+ // horizontally add 8 int32_t
40+ static inline int hsum_i32_8 (const __m256i a) {
41+ const __m128i sum128 = _mm_add_epi32 (_mm256_castsi256_si128 (a), _mm256_extractf128_si256 (a, 1 ));
42+ const __m128i hi64 = _mm_unpackhi_epi64 (sum128, sum128);
43+ const __m128i sum64 = _mm_add_epi32 (hi64, sum128);
44+ const __m128i hi32 = _mm_shuffle_epi32 (sum64, _MM_SHUFFLE (2 , 3 , 0 , 1 ));
45+ return _mm_cvtsi128_si32 (_mm_add_epi32 (sum64, hi32));
46+ }
47+
48+ // horizontally add 4 int32_t
49+ static inline int hsum_i32_4 (const __m128i a) {
50+ const __m128i hi64 = _mm_unpackhi_epi64 (a, a);
51+ const __m128i sum64 = _mm_add_epi32 (hi64, a);
52+ const __m128i hi32 = _mm_shuffle_epi32 (sum64, _MM_SHUFFLE (2 , 3 , 0 , 1 ));
53+ return _mm_cvtsi128_si32 (_mm_add_epi32 (sum64, hi32));
54+ }
55+
56+ #if defined(__AVX2__) || defined(__AVX512F__)
57+ // spread 32 bits to 32 bytes { 0x00, 0xFF }
58+ static inline __m256i bytes_from_bits_32 (const uint8_t * x) {
59+ uint32_t x32;
60+ memcpy (&x32, x, sizeof (uint32_t ));
61+ const __m256i shuf_mask = _mm256_set_epi64x (
62+ 0x0303030303030303 , 0x0202020202020202 ,
63+ 0x0101010101010101 , 0x0000000000000000 );
64+ __m256i bytes = _mm256_shuffle_epi8 (_mm256_set1_epi32 (x32), shuf_mask);
65+ const __m256i bit_mask = _mm256_set1_epi64x (0x7fbfdfeff7fbfdfe );
66+ bytes = _mm256_or_si256 (bytes, bit_mask);
67+ return _mm256_cmpeq_epi8 (bytes, _mm256_set1_epi64x (-1 ));
68+ }
69+
3770// Unpack 32 4-bit fields into 32 bytes
3871// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
39- inline __m256i bytes_from_nibbles_32 (const uint8_t *rsi) {
72+ static inline __m256i bytes_from_nibbles_32 (const uint8_t * rsi)
73+ {
4074 const __m128i tmp = _mm_loadu_si128 ((const __m128i *)rsi);
41- const __m256i bytes = _mm256_set_m128i (_mm_srli_epi16 (tmp, 4 ), tmp);
42- const __m256i lowMask = _mm256_set1_epi8 (0xF );
75+ const __m256i bytes = MM256_SET_M128I (_mm_srli_epi16 (tmp, 4 ), tmp);
76+ const __m256i lowMask = _mm256_set1_epi8 ( 0xF );
4377 return _mm256_and_si256 (lowMask, bytes);
4478}
4579
4680// add int16_t pairwise and return as float vector
47- inline __m256 sum_i16_pairs_float (const __m256i x) {
81+ static inline __m256 sum_i16_pairs_float (const __m256i x) {
4882 const __m256i ones = _mm256_set1_epi16 (1 );
4983 const __m256i summed_pairs = _mm256_madd_epi16 (ones, x);
5084 return _mm256_cvtepi32_ps (summed_pairs);
5185}
5286
53- inline __m256 mul_sum_us8_pairs_float (const __m256i ax, const __m256i sy) {
54- #if defined(__AVXVNNI__) || (defined( __AVX512VNNI__) && defined(__AVX512VL__) )
87+ static inline __m256 mul_sum_us8_pairs_float (const __m256i ax, const __m256i sy) {
88+ #if defined(__AVX512VNNI__) && defined(__AVX512VL__)
5589 const __m256i zero = _mm256_setzero_si256 ();
5690 const __m256i summed_pairs = _mm256_dpbusd_epi32 (zero, ax, sy);
5791 return _mm256_cvtepi32_ps (summed_pairs);
92+ #elif defined(__AVXVNNI__)
93+ const __m256i zero = _mm256_setzero_si256 ();
94+ const __m256i summed_pairs = _mm256_dpbusd_avx_epi32 (zero, ax, sy);
95+ return _mm256_cvtepi32_ps (summed_pairs);
5896#else
5997 // Perform multiplication and create 16-bit values
6098 const __m256i dot = _mm256_maddubs_epi16 (ax, sy);
6199 return sum_i16_pairs_float (dot);
62100#endif
63101}
64102
65- inline __m256 mul_sum_i8_pairs_float (const __m256i x, const __m256i y) {
103+ // multiply int8_t, add results pairwise twice and return as float vector
104+ static inline __m256 mul_sum_i8_pairs_float (const __m256i x, const __m256i y) {
66105#if __AVXVNNIINT8__
67106 const __m256i zero = _mm256_setzero_si256 ();
68107 const __m256i summed_pairs = _mm256_dpbssd_epi32 (zero, x, y);
@@ -76,9 +115,155 @@ inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
76115#endif
77116}
78117
118+ static inline __m128i packNibbles ( __m256i bytes )
119+ {
120+ // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
121+ #if __AVX512F__
122+ const __m256i bytes_srli_4 = _mm256_srli_epi16 (bytes, 4 ); // 0000_0000_abcd_0000
123+ bytes = _mm256_or_si256 (bytes, bytes_srli_4); // 0000_abcd_abcd_efgh
124+ return _mm256_cvtepi16_epi8 (bytes); // abcd_efgh
125+ #else
126+ const __m256i lowByte = _mm256_set1_epi16 ( 0xFF );
127+ __m256i high = _mm256_andnot_si256 ( lowByte, bytes );
128+ __m256i low = _mm256_and_si256 ( lowByte, bytes );
129+ high = _mm256_srli_epi16 ( high, 4 );
130+ bytes = _mm256_or_si256 ( low, high );
131+
132+ // Compress uint16_t lanes into bytes
133+ __m128i r0 = _mm256_castsi256_si128 ( bytes );
134+ __m128i r1 = _mm256_extracti128_si256 ( bytes, 1 );
135+ return _mm_packus_epi16 ( r0, r1 );
136+ #endif
137+ }
138+ #elif defined(__AVX__)
139+ static inline __m128i packNibbles ( __m128i bytes1, __m128i bytes2 )
140+ {
141+ // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
142+ const __m128i lowByte = _mm_set1_epi16 ( 0xFF );
143+ __m128i high = _mm_andnot_si128 ( lowByte, bytes1 );
144+ __m128i low = _mm_and_si128 ( lowByte, bytes1 );
145+ high = _mm_srli_epi16 ( high, 4 );
146+ bytes1 = _mm_or_si128 ( low, high );
147+ high = _mm_andnot_si128 ( lowByte, bytes2 );
148+ low = _mm_and_si128 ( lowByte, bytes2 );
149+ high = _mm_srli_epi16 ( high, 4 );
150+ bytes2 = _mm_or_si128 ( low, high );
151+
152+ return _mm_packus_epi16 ( bytes1, bytes2);
153+ }
154+
155+ static inline __m128i mul_add_epi8_sse (const __m128i x, const __m128i y) {
156+ const __m128i ax = _mm_sign_epi8 (x, x);
157+ const __m128i sy = _mm_sign_epi8 (y, x);
158+ return _mm_maddubs_epi16 (ax, sy);
159+ }
160+
161+ // spread 32 bits to 32 bytes { 0x00, 0xFF }
162+ static inline __m256i bytes_from_bits_32 (const uint8_t * x) {
163+ uint32_t x32;
164+ memcpy (&x32, x, sizeof (uint32_t ));
165+ const __m128i shuf_maskl = _mm_set_epi64x (0x0101010101010101 , 0x0000000000000000 );
166+ const __m128i shuf_maskh = _mm_set_epi64x (0x0303030303030303 , 0x0202020202020202 );
167+ __m128i bytesl = _mm_shuffle_epi8 (_mm_set1_epi32 (x32), shuf_maskl);
168+ __m128i bytesh = _mm_shuffle_epi8 (_mm_set1_epi32 (x32), shuf_maskh);
169+ const __m128i bit_mask = _mm_set1_epi64x (0x7fbfdfeff7fbfdfe );
170+ bytesl = _mm_or_si128 (bytesl, bit_mask);
171+ bytesh = _mm_or_si128 (bytesh, bit_mask);
172+ bytesl = _mm_cmpeq_epi8 (bytesl, _mm_set1_epi64x (-1 ));
173+ bytesh = _mm_cmpeq_epi8 (bytesh, _mm_set1_epi64x (-1 ));
174+ return MM256_SET_M128I (bytesh, bytesl);
175+ }
176+
177+ // Unpack 32 4-bit fields into 32 bytes
178+ // The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
179+ static inline __m256i bytes_from_nibbles_32 (const uint8_t * rsi)
180+ {
181+ // Load 16 bytes from memory
182+ __m128i tmpl = _mm_loadu_si128 ((const __m128i *)rsi);
183+ __m128i tmph = _mm_srli_epi16 (tmpl, 4 );
184+ const __m128i lowMask = _mm_set1_epi8 (0xF );
185+ tmpl = _mm_and_si128 (lowMask, tmpl);
186+ tmph = _mm_and_si128 (lowMask, tmph);
187+ return MM256_SET_M128I (tmph, tmpl);
188+ }
189+
190+ // add int16_t pairwise and return as float vector
191+ static inline __m256 sum_i16_pairs_float (const __m128i xh, const __m128i xl) {
192+ const __m128i ones = _mm_set1_epi16 (1 );
193+ const __m128i summed_pairsl = _mm_madd_epi16 (ones, xl);
194+ const __m128i summed_pairsh = _mm_madd_epi16 (ones, xh);
195+ const __m256i summed_pairs = MM256_SET_M128I (summed_pairsh, summed_pairsl);
196+ return _mm256_cvtepi32_ps (summed_pairs);
197+ }
198+
199+ static inline __m256 mul_sum_us8_pairs_float (const __m256i ax, const __m256i sy) {
200+ const __m128i axl = _mm256_castsi256_si128 (ax);
201+ const __m128i axh = _mm256_extractf128_si256 (ax, 1 );
202+ const __m128i syl = _mm256_castsi256_si128 (sy);
203+ const __m128i syh = _mm256_extractf128_si256 (sy, 1 );
204+ // Perform multiplication and create 16-bit values
205+ const __m128i dotl = _mm_maddubs_epi16 (axl, syl);
206+ const __m128i doth = _mm_maddubs_epi16 (axh, syh);
207+ return sum_i16_pairs_float (doth, dotl);
79208}
80209
81- #endif // __AVX2__
210+ // multiply int8_t, add results pairwise twice and return as float vector
211+ static inline __m256 mul_sum_i8_pairs_float (const __m256i x, const __m256i y) {
212+ const __m128i xl = _mm256_castsi256_si128 (x);
213+ const __m128i xh = _mm256_extractf128_si256 (x, 1 );
214+ const __m128i yl = _mm256_castsi256_si128 (y);
215+ const __m128i yh = _mm256_extractf128_si256 (y, 1 );
216+ // Get absolute values of x vectors
217+ const __m128i axl = _mm_sign_epi8 (xl, xl);
218+ const __m128i axh = _mm_sign_epi8 (xh, xh);
219+ // Sign the values of the y vectors
220+ const __m128i syl = _mm_sign_epi8 (yl, xl);
221+ const __m128i syh = _mm_sign_epi8 (yh, xh);
222+ // Perform multiplication and create 16-bit values
223+ const __m128i dotl = _mm_maddubs_epi16 (axl, syl);
224+ const __m128i doth = _mm_maddubs_epi16 (axh, syh);
225+ return sum_i16_pairs_float (doth, dotl);
226+ }
227+
228+ // larger version of mul_sum_i8_pairs_float where x and y are each represented by four 128-bit vectors
229+ static inline __m256 mul_sum_i8_quad_float (const __m128i x_1_0, const __m128i x_1_1, const __m128i x_2_0, const __m128i x_2_1,
230+ const __m128i y_1_0, const __m128i y_1_1, const __m128i y_2_0, const __m128i y_2_1) {
231+ const __m128i mone = _mm_set1_epi16 (1 );
232+
233+ const __m128i p16_1_0 = mul_add_epi8_sse (x_1_0, y_1_0);
234+ const __m128i p16_1_1 = mul_add_epi8_sse (x_1_1, y_1_1);
235+ const __m128i p16_2_0 = mul_add_epi8_sse (x_2_0, y_2_0);
236+ const __m128i p16_2_1 = mul_add_epi8_sse (x_2_1, y_2_1);
237+ const __m128i p_1_0 = _mm_madd_epi16 (p16_1_0, mone);
238+ const __m128i p_1_1 = _mm_madd_epi16 (p16_1_1, mone);
239+ const __m128i p_2_0 = _mm_madd_epi16 (p16_2_0, mone);
240+ const __m128i p_2_1 = _mm_madd_epi16 (p16_2_1, mone);
241+ const __m128i p_1 = _mm_add_epi32 (p_1_0, p_1_1);
242+ const __m128i p_2 = _mm_add_epi32 (p_2_0, p_2_1);
243+ return _mm256_cvtepi32_ps (MM256_SET_M128I (p_2, p_1));
244+ }
245+
246+ // quad fp16 delta calculation
247+ static inline __m256 quad_fp16_delta_float (const float x0, const float y0, const float x1, const float y1) {
248+ // GGML_CPU_FP16_TO_FP32 is faster than Intel F16C
249+ return _mm256_set_m128 (_mm_set1_ps (GGML_CPU_FP16_TO_FP32 (x1) * GGML_CPU_FP16_TO_FP32 (y1)),
250+ _mm_set1_ps (GGML_CPU_FP16_TO_FP32 (x0) * GGML_CPU_FP16_TO_FP32 (y0)));
251+ }
252+ #endif
253+ #elif defined(__SSSE3__)
254+ // horizontally add 4x4 floats
255+ static inline float hsum_float_4x4 (const __m128 a, const __m128 b, const __m128 c, const __m128 d) {
256+ __m128 res_0 =_mm_hadd_ps (a, b);
257+ __m128 res_1 =_mm_hadd_ps (c, d);
258+ __m128 res =_mm_hadd_ps (res_0, res_1);
259+ res =_mm_hadd_ps (res, res);
260+ res =_mm_hadd_ps (res, res);
261+
262+ return _mm_cvtss_f32 (res);
263+ }
264+ #endif // __AVX__ || __AVX2__ || __AVX512F__
265+ }
266+ #endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
82267
83268namespace az ::cpu {
84269
0 commit comments