1212#include < functional>
1313#include < unordered_map>
1414
15+ #include " kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa.h"
1516#include " kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.h"
1617#include " kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme.h"
1718#include " kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme.h"
@@ -161,24 +162,7 @@ static bool CheckCapabilitiesSme(const MLAS_CONV_PARAMETERS* Parameters) {
161162 return false ;
162163 }
163164
164- // optimization checks - is the implementation optimal for the conv request
165-
166- const auto n_step = kai_get_n_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa ();
167- const auto m_step = kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa ();
168-
169- auto M = ComputeConvOutSize (Parameters->InputShape [0 ], ComputeKernelSize (Parameters->DilationShape [0 ],
170- Parameters->KernelShape [0 ]), Parameters->Padding [0 ], Parameters->StrideShape [0 ]) *
171- ComputeConvOutSize (Parameters->InputShape [1 ], ComputeKernelSize (Parameters->DilationShape [1 ],
172- Parameters->KernelShape [1 ]), Parameters->Padding [1 ], Parameters->StrideShape [1 ]);
173165 auto N = Parameters->FilterCount ;
174- auto K = Parameters->InputChannels * Parameters->KernelShape [0 ] * Parameters->KernelShape [1 ];
175-
176- // Can use these variables to add other conditions as required
177- MLAS_UNREFERENCED_PARAMETER (M);
178- MLAS_UNREFERENCED_PARAMETER (K);
179- MLAS_UNREFERENCED_PARAMETER (m_step);
180- MLAS_UNREFERENCED_PARAMETER (n_step);
181-
182166 if (N == 1 || Parameters->KernelShape [0 ] < 3 || Parameters->KernelShape [1 ] < 3 ) {
183167 KLEIDIAI_DEBUG_LOG (" CheckCapabilitiesSme returning false on optimization checks." );
184168 return false ;
@@ -314,8 +298,8 @@ static void MultiThreadedLHSPackSme(MLAS_THREADPOOL* ThreadPool, const size_t ci
314298 const size_t kw, const void * const * lhs_ptrs, std::byte* lhs_data,
315299 const float * in_data,
316300 const float * pad_ptr) {
317-
318- auto m_step = kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa ();
301+ size_t m_step = ArmKleidiAI::UseSME2 ? kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa ()
302+ : kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa ();
319303
320304 // Minimize the kernel call count for the number of available threads
321305 auto RequiredTiles = MlasDivRoundup (m, m_step);
@@ -399,7 +383,9 @@ static std::shared_ptr<const void*[]> LhsPtrFill(const size_t ci, const size_t i
399383
400384 const auto m = ComputeConvOutSize (ih, kh, padding, sh) * ComputeConvOutSize (iw, kw, padding, sw);
401385
402- const auto m_step = kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa ();
386+ const auto m_step = ArmKleidiAI::UseSME2 ? kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa ()
387+ : kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa ();
388+
403389 const auto lhs_ptrs_k = kh * kw;
404390 const auto lhs_ptrs_m = m_step * MlasDivRoundup (m, m_step);
405391 auto lhs_ptrs = std::shared_ptr<const void *[]>(new const void *[lhs_ptrs_k * lhs_ptrs_m],
@@ -505,13 +491,13 @@ static std::unique_ptr<std::byte[]> LhsPackImageDataSme(const size_t ci, const s
505491}
506492
507493static void ConvolveSme (const size_t co, // channels out
508- const size_t ci, // channels in
509- const size_t ih, // image height
510- const size_t iw, // image width
511- const size_t kh, // kernel height
512- const size_t kw, // kernel width
513- const size_t sh, // kernel stride height
514- const size_t sw, // kernel stride width
494+ const size_t ci, // channels in
495+ const size_t ih, // image height
496+ const size_t iw, // image width
497+ const size_t kh, // kernel height
498+ const size_t kw, // kernel width
499+ const size_t sh, // kernel stride height
500+ const size_t sw, // kernel stride width
515501 const size_t dilationh, // kernel dilation stride
516502 const size_t dilationw, // kernel dilation stride
517503 const size_t padding, // padding size
@@ -532,10 +518,12 @@ static void ConvolveSme(const size_t co, //channels out
532518 const auto m = ComputeConvOutSize (ih, d_kh, padding, sh) *
533519 ComputeConvOutSize (iw, d_kw, padding, sw);
534520
535- auto n_step = kai_get_n_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa ();
536- auto m_step = kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa ();
521+ size_t n_step = ArmKleidiAI::UseSME2 ? kai_get_n_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa ()
522+ : kai_get_n_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa ();
523+ size_t m_step = ArmKleidiAI::UseSME2 ? kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa ()
524+ : kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa ();
537525
538- // tile iteration dimensions
526+ // tile iteration dimensions
539527 std::array<size_t ,3 > dim;
540528 dim[0 ] = 1 ; // B
541529 dim[1 ] = MlasDivRoundup (m, m_step); // M
@@ -571,29 +559,23 @@ static void ConvolveSme(const size_t co, //channels out
571559 auto lhs = LhsPackImageDataSme (ci, ih, iw, d_kh, d_kw, sh, sw, padding, in, ThreadPool);
572560 auto rhs = RhsPackWeightsBiasSme (co, ci, kh, kw, dilationh, dilationw, weights, bias, ThreadPool);
573561
574-
575- MlasTrySimpleParallel (ThreadPool,
576- static_cast <ptrdiff_t >(dim[0 ]*dim[1 ]*dim[2 ]),
577- [&](ptrdiff_t tid)
578- {
562+ MlasTrySimpleParallel (ThreadPool, static_cast <ptrdiff_t >(dim[0 ] * dim[1 ] * dim[2 ]), [&](ptrdiff_t tid) {
579563 // compute B,M,N index from iteration index
580564 // ptrdiff_t BIdx = tid / (dim[1] * dim[2]);
581565 ptrdiff_t MIdx = (tid % (dim[1 ] * dim[2 ])) / dim[2 ];
582566 ptrdiff_t NIdx = (tid % (dim[1 ] * dim[2 ])) % dim[2 ];
583567
584568 // Get rhs tile, B
585- const size_t rhs_packed_offset =
586- kai_get_rhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa (NIdx*n_step,
587- d_kh*d_kw,ci);
569+ const size_t rhs_packed_offset = ArmKleidiAI::UseSME2 ? kai_get_rhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa (NIdx * n_step, d_kh * d_kw, ci)
570+ : kai_get_rhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa (NIdx * n_step, d_kh * d_kw, ci);
588571
589572 auto BTile = reinterpret_cast <const void *>(
590573 reinterpret_cast <const std::byte*>(rhs.get ()) + rhs_packed_offset
591574 );
592575
593576 // Get lhs tile, A
594- const size_t lhs_packed_offset =
595- kai_get_lhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa (MIdx*m_step,
596- d_kh*d_kw,ci);
577+ const size_t lhs_packed_offset = ArmKleidiAI::UseSME2 ? kai_get_lhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa (MIdx * m_step, d_kh * d_kw, ci)
578+ : kai_get_lhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa (MIdx * m_step, d_kh * d_kw, ci);
597579
598580 auto ATile = reinterpret_cast <const float *>(
599581 reinterpret_cast <const std::byte*>(lhs.get ()) + lhs_packed_offset
@@ -607,12 +589,19 @@ static void ConvolveSme(const size_t co, //channels out
607589 MIdx * m_step * co * sizeof (float ) +
608590 NIdx * n_step * sizeof (float )];
609591
610- KLEIDIAI_KERNEL_LOG (" kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa"
611- << " M=" << TileSizeM << " N=" << TileSizeN << " k_chunk_count=" << (d_kh*d_kw) << " k_chunk_length=" << ci);
612- kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa (
613- TileSizeM, TileSizeN, d_kh*d_kw, ci, ATile, BTile, CTile, co * sizeof (float ),
614- -std::numeric_limits<float >::max (), std::numeric_limits<float >::max ()
615- );
592+ if (ArmKleidiAI::UseSME2) {
593+ KLEIDIAI_KERNEL_LOG (" kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa" << " M=" << TileSizeM << " N=" << TileSizeN << " k_chunk_count=" << (d_kh * d_kw) << " k_chunk_length=" << ci);
594+ kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa (
595+ TileSizeM, TileSizeN, d_kh * d_kw, ci, ATile, BTile, CTile, co * sizeof (float ),
596+ -std::numeric_limits<float >::max (), std::numeric_limits<float >::max ()
597+ );
598+ } else {
599+ KLEIDIAI_KERNEL_LOG (" kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa" << " M=" << TileSizeM << " N=" << TileSizeN << " k_chunk_count=" << (d_kh * d_kw) << " k_chunk_length=" << ci);
600+ kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa (
601+ TileSizeM, TileSizeN, d_kh * d_kw, ci, ATile, BTile, CTile, co * sizeof (float ),
602+ -std::numeric_limits<float >::max (), std::numeric_limits<float >::max ()
603+ );
604+ }
616605 });
617606
618607 if (result == tmp_mlas_aligned) {
@@ -712,11 +701,11 @@ ArmKleidiAI::MlasConv(
712701 )
713702{
714703 if (!CheckCapabilitiesSme (Parameters)){
715- // Fallback to Default Mlas
704+ // Fallback to Default Mlas
716705 return false ;
717706 };
718707 ConvolveSme (Parameters->FilterCount , Parameters->InputChannels , // channel out, in
719- Parameters->InputShape [0 ], Parameters->InputShape [1 ], // image dimensions
708+ Parameters->InputShape [0 ], Parameters->InputShape [1 ], // image dimensions
720709 Parameters->KernelShape [0 ], Parameters->KernelShape [1 ], // kernel dimensions
721710 Parameters->StrideShape [0 ], Parameters->StrideShape [1 ], // kernel stride dimensions
722711 Parameters->DilationShape [0 ], Parameters->DilationShape [1 ], // kernel dilation
0 commit comments