You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: tests/test_gemm_nt.cpp
+141-1Lines changed: 141 additions & 1 deletion
Original file line number
Diff line number
Diff line change
@@ -70,6 +70,146 @@ static int test_gemm_0(int M, int N, int K)
70
70
|| test_gemm_nt(M, N, K, 1, 1, 0, 1, 1);
71
71
}
72
72
73
+
staticinttest_gemm_bias_nt(int M, int N, int K, const ncnn::Mat& C, float alpha, float beta, int transA, int transB, int output_transpose, int constantA, int constantB, int constantC)
74
+
{
75
+
int broadcast_type_C = 0;
76
+
if (C.dims == 1 && C.w == 1)
77
+
{
78
+
// scalar
79
+
broadcast_type_C = 0;
80
+
}
81
+
if (C.dims == 1 && C.w == M)
82
+
{
83
+
// M
84
+
// auto broadcast from h to w is the ncnn-style convention
85
+
broadcast_type_C = 1;
86
+
}
87
+
if (C.dims == 1 && C.w == N)
88
+
{
89
+
// N
90
+
broadcast_type_C = 4;
91
+
}
92
+
if (C.dims == 2 && C.w == 1 && C.h == M)
93
+
{
94
+
// Mx1
95
+
broadcast_type_C = 2;
96
+
}
97
+
if (C.dims == 2 && C.w == N && C.h == M)
98
+
{
99
+
// MxN
100
+
broadcast_type_C = 3;
101
+
}
102
+
if (C.dims == 2 && C.w == N && C.h == 1)
103
+
{
104
+
// 1xN
105
+
broadcast_type_C = 4;
106
+
}
107
+
108
+
ncnn::ParamDict pd;
109
+
pd.set(0, alpha);
110
+
pd.set(1, beta);
111
+
pd.set(2, transA);
112
+
pd.set(3, transB);
113
+
pd.set(4, constantA);
114
+
pd.set(5, constantB);
115
+
pd.set(6, constantC);
116
+
pd.set(7, M);
117
+
pd.set(8, N);
118
+
pd.set(9, K);
119
+
pd.set(10, broadcast_type_C);
120
+
pd.set(14, output_transpose);
121
+
122
+
std::vector<ncnn::Mat> weights;
123
+
if (constantA) weights.push_back(transA ? ncnn::Mat(M, K) : ncnn::Mat(K, M));
124
+
if (constantB) weights.push_back(transB ? ncnn::Mat(K, N) : ncnn::Mat(N, K));
125
+
if (constantC) weights.push_back(C);
126
+
127
+
std::vector<ncnn::Mat> a;
128
+
if (!constantA) a.push_back(transA ? ncnn::Mat(M, K) : ncnn::Mat(K, M));
129
+
if (!constantB) a.push_back(transB ? ncnn::Mat(K, N) : ncnn::Mat(N, K));
130
+
if (!constantC) a.push_back(C);
131
+
132
+
for (size_t i = 0; i < weights.size(); i++)
133
+
{
134
+
Randomize(weights[i]);
135
+
}
136
+
137
+
for (size_t i = 0; i < a.size(); i++)
138
+
{
139
+
Randomize(a[i]);
140
+
}
141
+
142
+
float epsilon = 0.001;
143
+
144
+
int ret = test_layer("Gemm", pd, weights, a, 1, epsilon, TEST_LAYER_ENABLE_THREADING);
0 commit comments