
Commit 8fe4804

Adding SME1 Convolution Kernel to convolve_kleidiai.cpp (#26402)

### Description

- Integration of the SME1 variant of the existing SME2 convolution kernel, kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa, and its associated packing functions
- Formatting changes in convolve_kleidiai.cpp
- Addition of a proper SME2 gate for dynamic qgemm
- Updating of the KleidiAI version (1.14 is the first version that contains the appropriate kernel; deps.txt now pins v1.15.0)

---------

Signed-off-by: Jonathan Clohessy <[email protected]>
Co-authored-by: Colm Donelan <[email protected]>
1 parent: d6219b6 · commit: 8fe4804
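The heart of the change is a single runtime switch, ArmKleidiAI::UseSME2: every kernel-parameter query (m_step, n_step, packed offsets) and the kernel invocation itself now select between the _sme2_mopa and _sme_mopa variants. Below is a condensed, self-contained sketch of that dispatch pattern — the stub getters and their return values are illustrative stand-ins, not KleidiAI's real API:

```cpp
#include <cstddef>
#include <cstdio>

// Illustrative stand-ins for the KleidiAI step getters named in the diff
// (kai_get_m_step_imatmul_..._sme2_mopa / ..._sme_mopa). Values are made up.
static std::size_t get_m_step_sme2() { return 32; }
static std::size_t get_m_step_sme1() { return 16; }

namespace ArmKleidiAI {
// The real header initializes this from MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2().
inline const bool UseSME2 = false;  // pretend we are on an SME1-only core
}

int main() {
    // The pattern the commit applies at every kernel-dependent call site:
    // query tile geometry from whichever kernel variant will actually run,
    // so packing and execution always agree on step sizes.
    const std::size_t m_step = ArmKleidiAI::UseSME2 ? get_m_step_sme2()
                                                    : get_m_step_sme1();
    std::printf("m_step=%zu (%s kernel)\n", m_step,
                ArmKleidiAI::UseSME2 ? "SME2" : "SME1");
    return 0;
}
```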

File tree: 6 files changed, +50 −62 lines


cmake/deps.txt — 1 addition, 1 deletion

```diff
@@ -56,5 +56,5 @@ extensions;https://github.com/microsoft/onnxruntime-extensions/archive/c24b7bab0
 directx_headers;https://github.com/microsoft/DirectX-Headers/archive/refs/tags/v1.613.1.zip;47653509a3371eabb156360f42faf582f314bf2e
 cudnn_frontend;https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.12.0.zip;7e733cfdc410d777b76122d64232499205589a96
 dawn;https://github.com/google/dawn/archive/13c1635a14574ebb7116b56a69f5519301417fda.zip;0aadd28fc385cf7d657d5fc70a352372d2d3c76a
-kleidiai;https://github.com/ARM-software/kleidiai/archive/refs/tags/v1.10.0.tar.gz;11b62149cb2514b3b9069cc435c3aa7a4e82b97a
+kleidiai;https://github.com/ARM-software/kleidiai/archive/refs/tags/v1.15.0.tar.gz;62ccd24ab60bcef68766440fb42d79071ac2a5d2
 duktape;https://github.com/svaarala/duktape/releases/download/v2.7.0/duktape-2.7.0.tar.xz;8200c8e417dbab7adcc12c4dbdef7651cfc55794
```
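For context: each deps.txt entry has the form name;download-url;checksum, so a version bump changes both the archive tag and its hash — KleidiAI moves from v1.10.0 to v1.15.0 here.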

onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc — 3 additions, 3 deletions

```diff
@@ -1,7 +1,7 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
-#include "core/common/cpuid_info.h"  // for CPUIDInfo::GetCPUIDInfo().HasArm_SME()
+#include "core/common/cpuid_info.h"  // for CPUIDInfo::GetCPUIDInfo().HasArm_SME2()
 #include "core/common/narrow.h"
 #include "core/common/safeint.h"
 #include "core/mlas/inc/mlas.h"
@@ -213,9 +213,9 @@ class DynamicQuantizeMatMul final : public MatMulIntegerToFloatBase {
     }
   }
 
-  // Currently, MlasDynamicQGemmBatch() and associated functions require SME or else they are no-ops.
+  // Currently, MlasDynamicQGemmBatch() and associated functions require SME2 or else they are no-ops.
   // We check that here too before attempting to use them.
-  if (!CPUIDInfo::GetCPUIDInfo().HasArm_SME()) {
+  if (!CPUIDInfo::GetCPUIDInfo().HasArm_SME2()) {
    can_use_dynamic_quant_mlas_ = false;
  }
```

onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp — 38 additions, 49 deletions

```diff
@@ -12,6 +12,7 @@
 #include <functional>
 #include <unordered_map>
 
+#include "kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa.h"
 #include "kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.h"
 #include "kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme.h"
 #include "kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme.h"
@@ -161,24 +162,7 @@ static bool CheckCapabilitiesSme(const MLAS_CONV_PARAMETERS* Parameters) {
         return false;
     }
 
-    //optimization checks - is the implementation optimal for the conv request
-
-    const auto n_step = kai_get_n_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa();
-    const auto m_step = kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa();
-
-    auto M = ComputeConvOutSize(Parameters->InputShape[0], ComputeKernelSize(Parameters->DilationShape[0],
-             Parameters->KernelShape[0]), Parameters->Padding[0], Parameters->StrideShape[0]) *
-             ComputeConvOutSize(Parameters->InputShape[1], ComputeKernelSize(Parameters->DilationShape[1],
-             Parameters->KernelShape[1]), Parameters->Padding[1], Parameters->StrideShape[1]);
     auto N = Parameters->FilterCount;
-    auto K = Parameters->InputChannels * Parameters->KernelShape[0] * Parameters->KernelShape[1];
-
-    //Can use these variables to add other conditions as required
-    MLAS_UNREFERENCED_PARAMETER(M);
-    MLAS_UNREFERENCED_PARAMETER(K);
-    MLAS_UNREFERENCED_PARAMETER(m_step);
-    MLAS_UNREFERENCED_PARAMETER(n_step);
-
     if (N == 1 || Parameters->KernelShape[0] < 3 || Parameters->KernelShape[1] < 3) {
         KLEIDIAI_DEBUG_LOG("CheckCapabilitiesSme returning false on optimization checks.");
         return false;
@@ -314,8 +298,8 @@ static void MultiThreadedLHSPackSme(MLAS_THREADPOOL* ThreadPool, const size_t ci
                                     const size_t kw, const void * const* lhs_ptrs, std::byte* lhs_data,
                                     const float* in_data,
                                     const float* pad_ptr) {
-
-    auto m_step = kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa();
+    size_t m_step = ArmKleidiAI::UseSME2 ? kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa()
+                                         : kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();
 
     // Minimize the kernel call count for the number of available threads
     auto RequiredTiles = MlasDivRoundup(m, m_step);
@@ -399,7 +383,9 @@ static std::shared_ptr<const void*[]> LhsPtrFill(const size_t ci, const size_t i
 
     const auto m = ComputeConvOutSize(ih, kh, padding, sh) * ComputeConvOutSize(iw, kw, padding, sw);
 
-    const auto m_step = kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa();
+    const auto m_step = ArmKleidiAI::UseSME2 ? kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa()
+                                             : kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();
+
     const auto lhs_ptrs_k = kh * kw;
     const auto lhs_ptrs_m = m_step * MlasDivRoundup(m, m_step);
     auto lhs_ptrs = std::shared_ptr<const void*[]>(new const void*[lhs_ptrs_k * lhs_ptrs_m],
@@ -505,13 +491,13 @@ static std::unique_ptr<std::byte[]> LhsPackImageDataSme(const size_t ci, const s
 }
 
 static void ConvolveSme(const size_t co,        //channels out
-                        const size_t ci,   //channels in
-                        const size_t ih,   //image height
-                        const size_t iw,   //image width
-                        const size_t kh,   //kernel height
-                        const size_t kw,   //kernel width
-                        const size_t sh,   //kernel stride height
-                        const size_t sw,   //kernel stride width
+                        const size_t ci,        //channels in
+                        const size_t ih,        //image height
+                        const size_t iw,        //image width
+                        const size_t kh,        //kernel height
+                        const size_t kw,        //kernel width
+                        const size_t sh,        //kernel stride height
+                        const size_t sw,        //kernel stride width
                         const size_t dilationh, //kernel dilation stride
                         const size_t dilationw, //kernel dilation stride
                         const size_t padding,   //padding size
@@ -532,10 +518,12 @@ static void ConvolveSme(const size_t co,        //channels out
     const auto m = ComputeConvOutSize(ih, d_kh, padding, sh) *
                    ComputeConvOutSize(iw, d_kw, padding, sw);
 
-    auto n_step = kai_get_n_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa();
-    auto m_step = kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa();
+    size_t n_step = ArmKleidiAI::UseSME2 ? kai_get_n_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa()
+                                         : kai_get_n_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();
+    size_t m_step = ArmKleidiAI::UseSME2 ? kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa()
+                                         : kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();
 
-    //tile iteration dimensions
+    // tile iteration dimensions
     std::array<size_t,3> dim;
     dim[0] = 1;                         // B
     dim[1] = MlasDivRoundup(m, m_step); // M
@@ -571,29 +559,23 @@ static void ConvolveSme(const size_t co,        //channels out
     auto lhs = LhsPackImageDataSme(ci, ih, iw, d_kh, d_kw, sh, sw, padding, in, ThreadPool);
     auto rhs = RhsPackWeightsBiasSme(co, ci, kh, kw, dilationh, dilationw, weights, bias, ThreadPool);
 
-
-    MlasTrySimpleParallel(ThreadPool,
-                          static_cast<ptrdiff_t>(dim[0]*dim[1]*dim[2]),
-                          [&](ptrdiff_t tid)
-    {
+    MlasTrySimpleParallel(ThreadPool, static_cast<ptrdiff_t>(dim[0] * dim[1] * dim[2]), [&](ptrdiff_t tid) {
         //compute B,M,N index from iteration index
         //ptrdiff_t BIdx = tid / (dim[1] * dim[2]);
         ptrdiff_t MIdx = (tid % (dim[1] * dim[2])) / dim[2];
         ptrdiff_t NIdx = (tid % (dim[1] * dim[2])) % dim[2];
 
         // Get rhs tile, B
-        const size_t rhs_packed_offset =
-            kai_get_rhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(NIdx*n_step,
-                                                                                               d_kh*d_kw,ci);
+        const size_t rhs_packed_offset = ArmKleidiAI::UseSME2 ? kai_get_rhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(NIdx * n_step, d_kh * d_kw, ci)
+                                                              : kai_get_rhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(NIdx * n_step, d_kh * d_kw, ci);
 
         auto BTile = reinterpret_cast<const void*>(
             reinterpret_cast<const std::byte*>(rhs.get()) + rhs_packed_offset
         );
 
         // Get lhs tile, A
-        const size_t lhs_packed_offset =
-            kai_get_lhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(MIdx*m_step,
-                                                                                               d_kh*d_kw,ci);
+        const size_t lhs_packed_offset = ArmKleidiAI::UseSME2 ? kai_get_lhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(MIdx * m_step, d_kh * d_kw, ci)
+                                                              : kai_get_lhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(MIdx * m_step, d_kh * d_kw, ci);
 
         auto ATile = reinterpret_cast<const float*>(
             reinterpret_cast<const std::byte*>(lhs.get()) + lhs_packed_offset
@@ -607,12 +589,19 @@ static void ConvolveSme(const size_t co,        //channels out
                            MIdx * m_step * co * sizeof(float) +
                            NIdx * n_step * sizeof(float)];
 
-        KLEIDIAI_KERNEL_LOG("kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa"
-                            << " M=" << TileSizeM << " N=" << TileSizeN << " k_chunk_count=" << (d_kh*d_kw) << " k_chunk_length=" << ci);
-        kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(
-            TileSizeM, TileSizeN, d_kh*d_kw, ci, ATile, BTile, CTile, co * sizeof(float),
-            -std::numeric_limits<float>::max(), std::numeric_limits<float>::max()
-        );
+        if (ArmKleidiAI::UseSME2) {
+            KLEIDIAI_KERNEL_LOG("kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa" << " M=" << TileSizeM << " N=" << TileSizeN << " k_chunk_count=" << (d_kh * d_kw) << " k_chunk_length=" << ci);
+            kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(
+                TileSizeM, TileSizeN, d_kh * d_kw, ci, ATile, BTile, CTile, co * sizeof(float),
+                -std::numeric_limits<float>::max(), std::numeric_limits<float>::max()
+            );
+        } else {
+            KLEIDIAI_KERNEL_LOG("kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa" << " M=" << TileSizeM << " N=" << TileSizeN << " k_chunk_count=" << (d_kh * d_kw) << " k_chunk_length=" << ci);
+            kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(
+                TileSizeM, TileSizeN, d_kh * d_kw, ci, ATile, BTile, CTile, co * sizeof(float),
+                -std::numeric_limits<float>::max(), std::numeric_limits<float>::max()
+            );
+        }
     });
 
     if (result == tmp_mlas_aligned) {
@@ -712,11 +701,11 @@ ArmKleidiAI::MlasConv(
     )
 {
     if(!CheckCapabilitiesSme(Parameters)){
-        //Fallback to Default Mlas
+        // Fallback to Default Mlas
         return false;
     };
     ConvolveSme(Parameters->FilterCount, Parameters->InputChannels,      // channel out, in
-                Parameters->InputShape[0], Parameters->InputShape[1],   // image dimensions
+                Parameters->InputShape[0], Parameters->InputShape[1],    // image dimensions
                 Parameters->KernelShape[0], Parameters->KernelShape[1], // kernel dimensions
                 Parameters->StrideShape[0], Parameters->StrideShape[1], // kernel stride dimensions
                 Parameters->DilationShape[0], Parameters->DilationShape[1], // kernel dilation
```
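One detail worth seeing in isolation: the parallel section flattens the (B, M-tile, N-tile) space into a single tid and recovers the tile coordinates with division and modulus, with tile counts derived from the step sizes of the selected kernel. A standalone sketch of that arithmetic — the sizes are illustrative, and a plain ceil-division helper stands in for MlasDivRoundup:

```cpp
#include <array>
#include <cstddef>
#include <cstdio>

// Stand-in for MlasDivRoundup: ceiling division.
static std::size_t DivRoundup(std::size_t a, std::size_t b) {
    return (a + b - 1) / b;
}

int main() {
    const std::size_t m = 100, n = 70;          // illustrative output dims
    const std::size_t m_step = 32, n_step = 32; // per-kernel tile steps

    // Tile iteration dimensions, as built in ConvolveSme.
    std::array<std::size_t, 3> dim;
    dim[0] = 1;                      // B
    dim[1] = DivRoundup(m, m_step);  // M tiles
    dim[2] = DivRoundup(n, n_step);  // N tiles

    for (std::size_t tid = 0; tid < dim[0] * dim[1] * dim[2]; ++tid) {
        // Same index recovery as the MlasTrySimpleParallel lambda.
        const std::size_t MIdx = (tid % (dim[1] * dim[2])) / dim[2];
        const std::size_t NIdx = (tid % (dim[1] * dim[2])) % dim[2];
        std::printf("tid=%zu -> MIdx=%zu NIdx=%zu\n", tid, MIdx, NIdx);
    }
    return 0;
}
```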

onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h — 2 additions, 3 deletions

```diff
@@ -6,7 +6,7 @@
 
 #pragma once
 
-#include "mlasi.h"
+#include "../mlasi.h"
 #include <iostream>
 
 // Fix to ensure compatibility with MSVC build
@@ -50,13 +50,12 @@
 #endif
 
 namespace ArmKleidiAI {
+
 // By default we should try for SME2 first before falling back to SME.
 inline const bool UseSME2 = MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2();
 
-//
 // Buffer packing routines.
 //
-
 size_t
 MLASCALL
 MlasGemmPackBSize(
```
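Because UseSME2 is a C++17 inline variable, the CPUID probe runs once during static initialization and every translation unit that includes the header observes the same value. A minimal model of that behavior, with a stub probe in place of MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2():

```cpp
#include <cstdio>

// Stub probe standing in for MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2().
static bool ProbeCpu() {
    std::puts("feature probe runs once, before main()");
    return true;
}

// C++17 inline variable: one shared instance across all translation units,
// dynamically initialized exactly once.
inline const bool UseSME2 = ProbeCpu();

int main() {
    std::printf("UseSME2 = %d\n", UseSME2);
    return 0;
}
```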

onnxruntime/core/mlas/lib/qgemm.cpp — 3 additions, 3 deletions

```diff
@@ -210,9 +210,9 @@ MlasDynamicQGemmBatch (
     MLAS_THREADPOOL* ThreadPool
 ) {
 #if defined(USE_KLEIDIAI) && !defined(_MSC_VER)
-    //No fallback and putting in guards
-    if(MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()){
-       ArmKleidiAI::MlasDynamicQGemmBatch(Shape, DataParams, BatchN, ThreadPool);
+    //No fallback and putting in guards. This implementation is SME2 specific.
+    if(ArmKleidiAI::UseSME2){
+        ArmKleidiAI::MlasDynamicQGemmBatch(Shape, DataParams, BatchN, ThreadPool);
     }
 #endif
```
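Note the asymmetry with the convolution path: ArmKleidiAI::MlasConv can return false and fall back to the default MLAS implementation, but MlasDynamicQGemmBatch() has no fallback — without SME2 it is a silent no-op. That is why callers such as DynamicQuantizeMatMul (above) and the unit test (below) check HasArm_SME2() themselves before relying on it.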

onnxruntime/test/mlas/unittest/test_dynamic_qgemm.cpp — 3 additions, 3 deletions

```diff
@@ -20,9 +20,9 @@ class MlasDynamicQgemmTest {
 
  public:
   void Test(size_t M, size_t N, size_t K, size_t BatchSize) {
-    // Currently, MlasDynamicQGemmBatch() and associated functions require SME or else they are no-ops.
-    if (!MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()) {
-      GTEST_SKIP() << "MlasDynamicQGemmBatch() requires ARM64 SME but it was not detected. Skipping test.";
+    // Currently, MlasDynamicQGemmBatch() and associated functions require SME2 or else they are no-ops.
+    if (!MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2()) {
+      GTEST_SKIP() << "MlasDynamicQGemmBatch() requires ARM64 SME2 but it was not detected. Skipping test.";
     }
 
     // Setup buffers for holding various data
```
