@@ -17,7 +17,6 @@ namespace ncnn {
1717
1818RotaryEmbed_x86::RotaryEmbed_x86 ()
1919{
20-
2120}
2221
2322int RotaryEmbed_x86::forward (const std::vector<Mat>& bottom_blobs,
@@ -81,64 +80,64 @@ int RotaryEmbed_x86::forward(const std::vector<Mat>& bottom_blobs,
8180 int j = 0 ;
8281
8382#if __SSE2__
84- #if __AVX512F__
83+ #if __AVX512F__
8584 for (; j + 15 < embed_dim / 2 ; j += 16 )
8685 {
8786 __m512 x0 = _mm512_loadu_ps (ptr0);
8887 __m512 x1 = _mm512_loadu_ps (ptr1);
89- __m512 c = _mm512_loadu_ps (cos_ptr);
90- __m512 s = _mm512_loadu_ps (sin_ptr);
88+ __m512 c = _mm512_loadu_ps (cos_ptr);
89+ __m512 s = _mm512_loadu_ps (sin_ptr);
9190
9291 __m512 y0 = _mm512_sub_ps (_mm512_mul_ps (x0, c), _mm512_mul_ps (x1, s));
9392 __m512 y1 = _mm512_add_ps (_mm512_mul_ps (x0, s), _mm512_mul_ps (x1, c));
9493
9594 _mm512_storeu_ps (outptr0, y0);
9695 _mm512_storeu_ps (outptr1, y1);
9796
98- ptr0 += 16 ;
99- ptr1 += 16 ;
97+ ptr0 += 16 ;
98+ ptr1 += 16 ;
10099 cos_ptr += 16 ;
101100 sin_ptr += 16 ;
102101 outptr0 += 16 ;
103102 outptr1 += 16 ;
104103 }
105- #elif __AVX__
104+ #elif __AVX__
106105 for (; j + 7 < embed_dim / 2 ; j += 8 )
107106 {
108107 __m256 x0 = _mm256_loadu_ps (ptr0);
109108 __m256 x1 = _mm256_loadu_ps (ptr1);
110- __m256 c = _mm256_loadu_ps (cos_ptr);
111- __m256 s = _mm256_loadu_ps (sin_ptr);
109+ __m256 c = _mm256_loadu_ps (cos_ptr);
110+ __m256 s = _mm256_loadu_ps (sin_ptr);
112111
113112 __m256 y0 = _mm256_sub_ps (_mm256_mul_ps (x0, c), _mm256_mul_ps (x1, s));
114113 __m256 y1 = _mm256_add_ps (_mm256_mul_ps (x0, s), _mm256_mul_ps (x1, c));
115114
116115 _mm256_storeu_ps (outptr0, y0);
117116 _mm256_storeu_ps (outptr1, y1);
118117
119- ptr0 += 8 ;
120- ptr1 += 8 ;
118+ ptr0 += 8 ;
119+ ptr1 += 8 ;
121120 cos_ptr += 8 ;
122121 sin_ptr += 8 ;
123122 outptr0 += 8 ;
124123 outptr1 += 8 ;
125124 }
126- #endif // __AVX__
125+ #endif // __AVX__
127126 for (; j + 3 < embed_dim / 2 ; j += 4 )
128127 {
129128 __m128 x0 = _mm_loadu_ps (ptr0);
130129 __m128 x1 = _mm_loadu_ps (ptr1);
131- __m128 c = _mm_loadu_ps (cos_ptr);
132- __m128 s = _mm_loadu_ps (sin_ptr);
130+ __m128 c = _mm_loadu_ps (cos_ptr);
131+ __m128 s = _mm_loadu_ps (sin_ptr);
133132
134133 __m128 y0 = _mm_sub_ps (_mm_mul_ps (x0, c), _mm_mul_ps (x1, s));
135134 __m128 y1 = _mm_add_ps (_mm_mul_ps (x0, s), _mm_mul_ps (x1, c));
136135
137136 _mm_storeu_ps (outptr0, y0);
138137 _mm_storeu_ps (outptr1, y1);
139138
140- ptr0 += 4 ;
141- ptr1 += 4 ;
139+ ptr0 += 4 ;
140+ ptr1 += 4 ;
142141 cos_ptr += 4 ;
143142 sin_ptr += 4 ;
144143 outptr0 += 4 ;
0 commit comments