@@ -50,33 +50,33 @@ int RotaryEmbed::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>&
5050
5151 for (int j = 0 ; j < embed_dim / 2 ; j++)
5252 {
53- const float x1 = ptr[0 ];
54- const float x2 = ptr[1 ];
53+ const float x0 = ptr[0 ];
54+ const float x1 = ptr[1 ];
5555 const float cos_val = *cos_ptr++;
5656 const float sin_val = *sin_ptr++;
57- outptr[0 ] = x1 * cos_val - x2 * sin_val;
58- outptr[1 ] = x1 * sin_val + x2 * cos_val;
57+ outptr[0 ] = x0 * cos_val - x1 * sin_val;
58+ outptr[1 ] = x0 * sin_val + x1 * cos_val;
5959 ptr += 2 ;
6060 outptr += 2 ;
6161 }
6262 }
6363 else
6464 {
65- const float * ptr1 = head.row (i);
66- const float * ptr2 = ptr1 + embed_dim / 2 ;
65+ const float * ptr0 = head.row (i);
66+ const float * ptr1 = ptr0 + embed_dim / 2 ;
6767 const float * sin_ptr = sin_cache.row (i);
6868 const float * cos_ptr = cos_cache.row (i);
69- float * outptr1 = out_head.row (i);
70- float * outptr2 = outptr1 + embed_dim / 2 ;
69+ float * outptr0 = out_head.row (i);
70+ float * outptr1 = outptr0 + embed_dim / 2 ;
7171
7272 for (int j = 0 ; j < embed_dim / 2 ; j++)
7373 {
74+ const float x0 = *ptr0++;
7475 const float x1 = *ptr1++;
75- const float x2 = *ptr2++;
7676 const float cos_val = *cos_ptr++;
7777 const float sin_val = *sin_ptr++;
78- *outptr1 ++ = x1 * cos_val - x2 * sin_val;
79- *outptr2 ++ = x1 * sin_val + x2 * cos_val;
78+ *outptr0 ++ = x0 * cos_val - x1 * sin_val;
79+ *outptr1 ++ = x0 * sin_val + x1 * cos_val;
8080 }
8181 }
8282 }
0 commit comments