@@ -13,9 +13,10 @@ Module Name:
1313
1414#pragma once
1515
16- #include " ../mlasi.h"
1716#include < arm_sve.h> // SVE intrinsic header
1817
18+ #include " ../mlasi.h"
19+
1920#ifndef __clang__
2021#pragma GCC push_options
2122#pragma GCC target("arch=armv8.2-a+sve")
@@ -34,132 +35,55 @@ typedef svuint32_t MLAS_SVUINT32;
3435typedef svbool_t MLAS_SVBOOL;
3536
3637// function declarations
37- MLAS_FORCEINLINE
38- MLAS_SVFLOAT32
39- MlasSveComputeExpVector (
40- MLAS_SVBOOL Pred,
41- MLAS_SVFLOAT32 Vector
42- );
38+ MLAS_FORCEINLINE MLAS_SVFLOAT32
39+ MlasSveComputeExpVector (MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector);
4340
44- void
45- MLASCALL
46- MlasSveComputeExpF32Kernel (
47- const float * Input,
48- float * Output,
49- size_t N
50- );
41+ void MLASCALL
42+ MlasSveComputeExpF32Kernel (const float * Input, float * Output, size_t N);
5143
5244MLAS_FORCEINLINE
5345MLAS_SVFLOAT32
54- MlasSveComputeSumExpVector (
55- MLAS_SVBOOL Pred,
56- MLAS_SVFLOAT32 Vector,
57- MLAS_SVFLOAT32 NegativeMaximumVector
58- );
46+ MlasSveComputeSumExpVector (MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector, MLAS_SVFLOAT32 NegativeMaximumVector);
5947
60- float
61- MLASCALL
62- MlasSveComputeSumExpF32Kernel (
63- const float * Input,
64- float * Output,
65- size_t N,
66- const float * NegativeMaximum
67- );
48+ float MLASCALL
49+ MlasSveComputeSumExpF32Kernel (const float * Input, float * Output, size_t N, const float * NegativeMaximum);
6850
6951float MLASCALL
70- MlasSveReduceMaximumF32Kernel (
71- const float * Input,
72- size_t N
73- );
52+ MlasSveReduceMaximumF32Kernel (const float * Input, size_t N);
7453
75- void
76- MLASCALL
77- MlasSveReduceMinimumMaximumF32Kernel (
78- const float * Input,
79- float * Min,
80- float * Max,
81- size_t N
82- );
54+ void MLASCALL
55+ MlasSveReduceMinimumMaximumF32Kernel (const float * Input, float * Min, float * Max, size_t N);
8356
84- void
85- MLASCALL
86- MlasSveComputeSoftmaxOutputF32Kernel (
87- float * Output,
88- size_t N,
89- const float * Parameters
90- );
57+ void MLASCALL
58+ MlasSveComputeSoftmaxOutputF32Kernel (float * Output, size_t N, const float * Parameters);
9159
92- void
93- MLASCALL
94- MlasSveComputeLogSoftmaxOutputF32Kernel (
95- const float * Input,
96- float * Output,
97- size_t N,
98- const float * Parameters
99- );
60+ void MLASCALL
61+ MlasSveComputeLogSoftmaxOutputF32Kernel (const float * Input, float * Output, size_t N, const float * Parameters);
10062
101- void
102- MLASCALL
103- MlasSveErfKernel (
104- const float * Input,
105- float * Output,
106- size_t N
107- );
63+ void MLASCALL
64+ MlasSveErfKernel (const float * Input, float * Output, size_t N);
10865
109- void
110- MLASCALL
111- MlasSveLogisticKernel (
112- const float * Input,
113- float * Output,
114- size_t N
115- );
66+ void MLASCALL
67+ MlasSveLogisticKernel (const float * Input, float * Output, size_t N);
11668
117- // MLAS API for SVE intrinsics
69+ // MLAS API for SVE intrinsics
11870size_t MLASCALL
119- MlasSgemmKernelAdd_sve (
120- const float * A,
121- const float * B,
122- float * C,
123- size_t CountK,
124- size_t CountM,
125- size_t CountN,
126- size_t lda,
127- size_t ldc,
128- float alpha
129- );
71+ MlasSgemmKernelAdd_sve (const float * A, const float * B, float * C, size_t CountK, size_t CountM, size_t CountN, size_t lda, size_t ldc, float alpha);
13072
13173size_t MLASCALL
132- MlasSgemmKernelZero_sve (
133- const float * A,
134- const float * B,
135- float * C,
136- size_t CountK,
137- size_t CountM,
138- size_t CountN,
139- size_t lda,
140- size_t ldc,
141- float alpha
142- );
143-
144- void MLAS_SVE_TARGET
145- MLASCALL
146- SVE_ZERO_INITIALIZE (float * d);
147-
148- void MLAS_SVE_TARGET
149- MLASCALL
150- SVE_LOAD_STORE (float * D, const float * b);
74+ MlasSgemmKernelZero_sve (const float * A, const float * B, float * C, size_t CountK, size_t CountM, size_t CountN, size_t lda, size_t ldc, float alpha);
75+
76+ void MLAS_SVE_TARGET MLASCALL
77+ SVE_ZERO_INITIALIZE (float * d);
78+
79+ void MLAS_SVE_TARGET MLASCALL
80+ SVE_LOAD_STORE (float * D, const float * b);
15181
15282void MLAS_SVE_TARGET MLASCALL
15383SCATTER_STORE (float * d, const float * b);
15484
155- void MLAS_SVE_TARGET
156- MLASCALL
157- SVE_TRANSPOSE (
158- float *& D,
159- const float *& b,
160- size_t ldb,
161- size_t & x
162- );
85+ void MLAS_SVE_TARGET MLASCALL
86+ SVE_TRANSPOSE (float *& D, const float *& b, size_t ldb, size_t & x);
16387
16488MLAS_SVE_TARGET
16589inline int
@@ -240,8 +164,8 @@ MLAS_SVE_TARGET
240164MLAS_FORCEINLINE
241165MLAS_SVINT32
242166MlasSveAddInt32 (MLAS_SVBOOL Pred, MLAS_SVINT32 Vector1, MLAS_SVINT32 Vector2)
243- {
244- return svadd_s32_m (Pred, Vector1, Vector2);
167+ {
168+ return svadd_s32_m (Pred, Vector1, Vector2);
245169}
246170
247171MLAS_SVE_TARGET
@@ -298,26 +222,26 @@ MLAS_SVINT32
298222MlasSveBlendInt32 (MLAS_SVBOOL Pred, MLAS_SVINT32 Vector1, MLAS_SVINT32 Vector2, MLAS_SVINT32 Selection)
299223{
300224 return MlasSveOrInt32 (
301- Pred,
302- MlasSveAndInt32 (Pred, Vector2, Selection),
225+ Pred,
226+ MlasSveAndInt32 (Pred, Vector2, Selection),
303227 MlasSveAndNotInt32 (Pred, Selection, Vector1)
304228 );
305229}
306230
307- template <unsigned ShiftCount>
231+ template <unsigned ShiftCount>
308232MLAS_SVE_TARGET
309- MLAS_FORCEINLINE
310- MLAS_SVUINT32
311- MlasSveShiftLeftUInt32 (MLAS_SVBOOL Pred, MLAS_SVUINT32 Vector)
233+ MLAS_FORCEINLINE
234+ MLAS_SVUINT32
235+ MlasSveShiftLeftUInt32 (MLAS_SVBOOL Pred, MLAS_SVUINT32 Vector)
312236{
313237 return svlsl_n_u32_z (Pred, Vector, ShiftCount);
314238}
315239
316- template <unsigned ShiftCount>
240+ template <unsigned ShiftCount>
317241MLAS_SVE_TARGET
318- MLAS_FORCEINLINE
319- MLAS_SVINT32
320- MlasSveShiftLeftInt32 (MLAS_SVBOOL Pred, MLAS_SVINT32 Vector)
242+ MLAS_FORCEINLINE
243+ MLAS_SVINT32
244+ MlasSveShiftLeftInt32 (MLAS_SVBOOL Pred, MLAS_SVINT32 Vector)
321245{
322246 return svlsl_n_s32_z (Pred, Vector, ShiftCount);
323247}
@@ -402,11 +326,10 @@ MlasSveStoreFloat32(MLAS_SVBOOL Pred, float* Buffer, MLAS_SVFLOAT32 Vector)
402326 svst1_f32 (Pred, Buffer, Vector);
403327}
404328
405- template <unsigned Lane>
329+ template <unsigned Lane>
406330MLAS_SVE_TARGET
407- MLAS_FORCEINLINE
408- void
409- MlasSveStoreLaneFloat32 (float * Buffer, MLAS_SVFLOAT32 Vector)
331+ MLAS_FORCEINLINE void
332+ MlasSveStoreLaneFloat32 (float * Buffer, MLAS_SVFLOAT32 Vector)
410333{
411334 svbool_t Pred = svwhilelt_b32 (Lane, Lane + 1 );
412335 svst1_f32 (Pred, Buffer, Vector);
@@ -421,11 +344,10 @@ MlasSveStoreLowHalfFloat32(float* Buffer, MLAS_SVFLOAT32 Vector)
421344 svst1_f32 (Pred, Buffer, Vector);
422345}
423346
424- template <unsigned Lane>
347+ template <unsigned Lane>
425348MLAS_SVE_TARGET
426- MLAS_FORCEINLINE
427- float
428- MlasSveExtractLaneFloat32 (MLAS_SVFLOAT32 Vector)
349+ MLAS_FORCEINLINE float
350+ MlasSveExtractLaneFloat32 (MLAS_SVFLOAT32 Vector)
429351{
430352 float TmpBuffer[1 ];
431353 svbool_t Pred = svwhilelt_b32 (Lane, Lane + 1 );
@@ -470,7 +392,7 @@ MLAS_FORCEINLINE
470392MLAS_SVFLOAT32
471393MlasSveMultiplyFloat32 (MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector1, MLAS_SVFLOAT32 Vector2)
472394{
473- return svmul_f32_m (Pred, Vector1, Vector2);
395+ return svmul_f32_m (Pred, Vector1, Vector2);
474396}
475397
476398MLAS_SVE_TARGET
@@ -484,7 +406,7 @@ MlasSveExpFloat32(MLAS_SVUINT32 Vector)
484406MLAS_SVE_TARGET
485407MLAS_FORCEINLINE
486408MLAS_SVFLOAT32
487- MlasSveScaleFloat32 (MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector1, MLAS_SVINT32 Vector2)
409+ MlasSveScaleFloat32 (MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector1, MLAS_SVINT32 Vector2)
488410{
489411 return svscale_f32_m (Pred, Vector1, Vector2);
490412}
@@ -494,7 +416,7 @@ MLAS_FORCEINLINE
494416MLAS_SVFLOAT32
495417MlasSveRoundINTFloat32 (MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector)
496418{
497- return svrintm_f32_z (Pred, Vector);
419+ return svrintm_f32_z (Pred, Vector);
498420}
499421
500422MLAS_SVE_TARGET
@@ -537,10 +459,10 @@ MlasSveGreaterThanFloat32(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector1, MLAS_SVFLOAT
537459 // Compare Vector1 and Vector2, return a predicate vector
538460 svbool_t cmp_mask = svcmpgt_f32 (Pred, Vector1, Vector2);
539461
540- // Convert predicate to uint32_t mask
462+ // Convert predicate to uint32_t mask
541463 svuint32_t mask_bits = svdup_u32_z (cmp_mask, 0xFFFFFFFF );
542464
543- // Reinterpret to float32
465+ // Reinterpret to float32
544466 return svreinterpret_f32_u32 (mask_bits);
545467}
546468
@@ -551,7 +473,7 @@ MlasSveAndFloat32(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector1, MLAS_SVFLOAT32 Vecto
551473{
552474 return MlasSveReinterpretAsFloat32 (
553475 MlasSveAndInt32 (
554- Pred,
476+ Pred,
555477 MlasSveReinterpretAsInt32 (Vector1),
556478 MlasSveReinterpretAsInt32 (Vector2)
557479 )
@@ -606,7 +528,7 @@ MLAS_SVFLOAT32
606528MlasSveBlendFloat32 (MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector1, MLAS_SVFLOAT32 Vector2, MLAS_SVFLOAT32 Selection)
607529{
608530 return MlasSveOrFloat32 (
609- Pred,
531+ Pred,
610532 MlasSveAndFloat32 (Pred, Vector2, Selection),
611533 MlasSveAndFloat32 (Pred, Vector1, Selection)
612534 );
@@ -668,8 +590,8 @@ MLAS_SVFLOAT32
668590MlasSvePowerOf2Float32 (MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector)
669591{
670592 MLAS_SVINT32 emm0 = MlasSveAddInt32 (
671- Pred,
672- MlasSveCastToInt32 (Pred, Vector),
593+ Pred,
594+ MlasSveCastToInt32 (Pred, Vector),
673595 MlasSveBroadcastInt32 (127 )
674596 );
675597 return MlasSveReinterpretAsFloat32 (MlasSveShiftLeftInt32<23 >(Pred, emm0));
@@ -700,8 +622,7 @@ MlasSveCompareLessThan(svbool_t Pred, MLAS_SVFLOAT32 A, MLAS_SVFLOAT32 B)
700622}
701623
702624MLASCALL
703- inline
704- void
625+ inline void
705626Transpose_SVE512_4x4 (float * D, const float * B, size_t ldb)
706627{
707628 const static int VL = svcntw ();
@@ -736,8 +657,7 @@ MlasSveCompareGreaterThan(svbool_t Pred, MLAS_SVFLOAT32 A, MLAS_SVFLOAT32 B)
736657}
737658
738659MLASCALL
739- inline
740- void
660+ inline void
741661Transpose_SVE256_4x4 (float * D, const float * B, size_t ldb)
742662{
743663 const static int VL = svcntw ();
0 commit comments