Skip to content

Commit e6e048e

Browse files
authored
KFI-203 Improve thread safety of packing in convolve_kleidiai.cpp (#26575)
### Description

Make the cache objects that hold packed data thread_local rather than static.

### Motivation and Context

Both LHS and RHS packing use a cache mechanism based on a static unordered map, so there is potential for interference between parallel inference sessions. Both cache structures are now thread_local.

Signed-off-by: Colm Donelan <[email protected]>
1 parent 8e951ef commit e6e048e

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -332,8 +332,8 @@ static std::shared_ptr<std::byte[]> RhsPackWeightsBiasSme(const size_t co, const
332332
const float* weights, const float* bias,
333333
MLAS_THREADPOOL* ThreadPool)
334334
{
335-
//cache of prepacked kai rhs weights and biases
336-
static std::unordered_map<RhsCacheKey, std::shared_ptr<std::byte[]>> rhs_cache;
335+
// Cache of prepacked kai rhs weights and biases. thread_local to prevent interference from parallel sessions.
336+
thread_local std::unordered_map<RhsCacheKey, std::shared_ptr<std::byte[]>> rhs_cache;
337337

338338
RhsCacheKey key = { co, ci, kh, kw, dilationh, dilationw, HashWeights(weights) };
339339

@@ -474,8 +474,8 @@ static std::unique_ptr<std::byte[]> LhsPackImageDataSme(const size_t ci, const s
474474

475475
auto nhwc = NChwToNhwc(1, ci, ih, iw, in, 1, 1, false, ThreadPool);
476476

477-
//cache of computed lhs ptr offsets
478-
static std::unordered_map<LhsCacheKey, std::shared_ptr<const void*[]>> lhs_ptrs_cache;
477+
// Cache of computed lhs ptr offsets. thread_local to prevent interference from parallel sessions.
478+
thread_local std::unordered_map<LhsCacheKey, std::shared_ptr<const void*[]>> lhs_ptrs_cache;
479479

480480
std::shared_ptr<const void*[]> lhs_ptrs;
481481
if (auto found = lhs_ptrs_cache.find(key); found != lhs_ptrs_cache.end()) {

0 commit comments

Comments
 (0)