Skip to content

Commit 96a6c7f

Browse files
committed
cc
1 parent affba20 commit 96a6c7f

File tree

2 files changed

+23
-1
lines changed

2 files changed

+23
-1
lines changed

src/layer/riscv/gemm_riscv.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1537,7 +1537,11 @@ static int gemm_riscv(const Mat& A, const Mat& B, const Mat& C, Mat& top_blob, i
15371537

15381538
Mat topT;
15391539
if (K > TILE_K || broadcast_type_C == 3 || output_transpose)
1540+
{
15401541
topT.create(TILE_N * TILE_M, 1, nT, 4u, opt.workspace_allocator);
1542+
if (topT.empty())
1543+
return -100;
1544+
}
15411545

15421546
#pragma omp parallel for num_threads(nT)
15431547
for (int ppi = 0; ppi < nn_M; ppi++)
@@ -1647,7 +1651,11 @@ static int gemm_AT_riscv(const Mat& AT, const Mat& B, const Mat& C, Mat& top_blo
16471651

16481652
Mat topT;
16491653
if (K > TILE_K || broadcast_type_C == 3 || output_transpose)
1654+
{
16501655
topT.create(TILE_N * TILE_M, 1, nT, 4u, opt.workspace_allocator);
1656+
if (topT.empty())
1657+
return -100;
1658+
}
16511659

16521660
#pragma omp parallel for num_threads(nT)
16531661
for (int ppi = 0; ppi < nn_M; ppi++)
@@ -1711,7 +1719,11 @@ static int gemm_BT_riscv(const Mat& A, const Mat& BT, const Mat& C, Mat& top_blo
17111719

17121720
Mat topT;
17131721
if (K > TILE_K || broadcast_type_C == 3 || output_transpose)
1722+
{
17141723
topT.create(TILE_N * TILE_M, 1, nT, 4u, opt.workspace_allocator);
1724+
if (topT.empty())
1725+
return -100;
1726+
}
17151727

17161728
#pragma omp parallel for num_threads(nT)
17171729
for (int ppi = 0; ppi < nn_M; ppi++)
@@ -1790,7 +1802,11 @@ static int gemm_AT_BT_riscv(const Mat& AT, const Mat& BT, const Mat& C, Mat& top
17901802

17911803
Mat topT;
17921804
if (K > TILE_K || broadcast_type_C == 3 || output_transpose)
1805+
{
17931806
topT.create(TILE_N * TILE_M, 1, nT, 4u, opt.workspace_allocator);
1807+
if (topT.empty())
1808+
return -100;
1809+
}
17941810

17951811
#pragma omp parallel for num_threads(nT)
17961812
for (int ppi = 0; ppi < nn_M; ppi++)
@@ -1951,6 +1967,8 @@ int Gemm_riscv::create_pipeline(const Option& opt)
19511967
{
19521968
int C_elempack = constantM % packn == 0 ? packn : 1;
19531969
convert_packing(C_data, CT_data, C_elempack, opt);
1970+
if (CT_data.empty())
1971+
return -100;
19541972
}
19551973
#endif // __riscv_vector
19561974

@@ -1959,6 +1977,8 @@ int Gemm_riscv::create_pipeline(const Option& opt)
19591977
{
19601978
Mat C2;
19611979
C2.create_like(CT_data);
1980+
if (C2.empty())
1981+
return -100;
19621982

19631983
const int size = CT_data.total() * CT_data.elempack;
19641984
for (int i = 0; i < size; i++)
@@ -2082,6 +2102,8 @@ int Gemm_riscv::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>&
20822102
{
20832103
Mat CT_data;
20842104
CT_data.create_like(C, opt.workspace_allocator);
2105+
if (CT_data.empty())
2106+
return -100;
20852107

20862108
const int size = C.total() * C.elempack;
20872109
for (int i = 0; i < size; i++)

tests/test_gemm_oom.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,7 @@ int main()
387387
return ret;
388388

389389
#if NCNN_INT8
390-
int ret2 = test_gemm_0(M, N, K) || test_gemm_1(M, N, K);
390+
int ret2 = test_gemm_2(M, N, K) || test_gemm_3(M, N, K) || test_gemm_4(M, N, K);
391391
if (ret2 != 0)
392392
return ret2;
393393
#endif

0 commit comments

Comments
 (0)