Skip to content

Commit fd8a2b3

Browse files
committed
rotaryembed layer
1 parent 69652f4 commit fd8a2b3

File tree

2 files changed

+114
-0
lines changed

2 files changed

+114
-0
lines changed

src/layer/rotaryembed.cpp

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
// Copyright 2025 Tencent
2+
// SPDX-License-Identifier: BSD-3-Clause
3+
4+
#include "rotaryembed.h"
5+
6+
namespace ncnn {
7+
8+
RotaryEmbed::RotaryEmbed()
9+
{
10+
}
11+
12+
int RotaryEmbed::load_param(const ParamDict& pd)
13+
{
14+
interleaved = pd.get(0, 0);
15+
16+
return 0;
17+
}
18+
19+
int RotaryEmbed::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
20+
{
21+
// assert bottom_blobs.size() == 3
22+
23+
const Mat& bottom_blob = bottom_blobs[0];
24+
const Mat& sin_cache = bottom_blobs[1];
25+
const Mat& cos_cache = bottom_blobs[2];
26+
27+
const int embed_dim = bottom_blob.w;
28+
const int seqlen = bottom_blob.h;
29+
const int num_heads = bottom_blob.c;
30+
31+
Mat& top_blob = top_blobs[0];
32+
top_blob.create_like(bottom_blob, opt.blob_allocator);
33+
if (top_blob.empty())
34+
return -100;
35+
36+
#pragma omp parallel for num_threads(opt.num_threads)
37+
for (int q = 0; q < num_heads; q++)
38+
{
39+
const Mat head = bottom_blob.channel(q);
40+
Mat out_head = top_blob.channel(q);
41+
42+
for (int i = 0; i < seqlen; i++)
43+
{
44+
if (interleaved)
45+
{
46+
const float* ptr = head.row(i);
47+
const float* sin_ptr = sin_cache.row(i);
48+
const float* cos_ptr = cos_cache.row(i);
49+
float* outptr = out_head.row(i);
50+
51+
for (int j = 0; j < embed_dim / 2; j++)
52+
{
53+
const float x1 = ptr[0];
54+
const float x2 = ptr[1];
55+
const float sin_val = *sin_ptr++;
56+
const float cos_val = *cos_ptr++;
57+
outptr[0] = x1 * cos_val - x2 * sin_val;
58+
outptr[1] = x1 * sin_val + x2 * cos_val;
59+
ptr += 2;
60+
outptr += 2;
61+
}
62+
}
63+
else
64+
{
65+
const float* ptr1 = head.row(i);
66+
const float* ptr2 = ptr1 + embed_dim / 2;
67+
const float* sin_ptr = sin_cache.row(i);
68+
const float* cos_ptr = cos_cache.row(i);
69+
float* outptr1 = out_head.row(i);
70+
float* outptr2 = outptr1 + embed_dim / 2;
71+
72+
for (int j = 0; j < embed_dim / 2; j++)
73+
{
74+
const float x1 = *ptr1++;
75+
const float x2 = *ptr2++;
76+
const float sin_val = *sin_ptr++;
77+
const float cos_val = *cos_ptr++;
78+
*outptr1++ = x1 * cos_val - x2 * sin_val;
79+
*outptr2++ = x1 * sin_val + x2 * cos_val;
80+
}
81+
}
82+
}
83+
}
84+
85+
return 0;
86+
}
87+
88+
} // namespace ncnn

src/layer/rotaryembed.h

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// Copyright 2025 Tencent
2+
// SPDX-License-Identifier: BSD-3-Clause
3+
4+
#ifndef LAYER_ROTARYEMBED_H
5+
#define LAYER_ROTARYEMBED_H
6+
7+
#include "layer.h"
8+
9+
namespace ncnn {
10+
11+
class RotaryEmbed : public Layer
12+
{
13+
public:
14+
RotaryEmbed();
15+
16+
virtual int load_param(const ParamDict& pd);
17+
18+
virtual int forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const;
19+
20+
public:
21+
int interleaved;
22+
};
23+
24+
} // namespace ncnn
25+
26+
#endif // LAYER_ROTARYEMBED_H

0 commit comments

Comments
 (0)