Skip to content

Commit db2bfe9

Browse files
committed
code formatting
1 parent 0699ab8 commit db2bfe9

File tree

6 files changed

+69
-156
lines changed

6 files changed

+69
-156
lines changed

cmake/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1879,4 +1879,4 @@ endif()
18791879
# Include precompiled header configuration for providers
18801880
if(TARGET onnxruntime_providers)
18811881
include("${CMAKE_CURRENT_SOURCE_DIR}/onnxruntime_providers_pch.cmake")
1882-
endif()
1882+
endif()

cmake/onnxruntime_mlas.cmake

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -919,4 +919,4 @@ if (NOT onnxruntime_ORT_MINIMAL_BUILD)
919919
endif()
920920
endif()
921921

922-
endif()
922+
endif()

onnxruntime/core/mlas/lib/mlasi.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,7 @@ class MLASCPUIDInfo
200200
bool HasArmSVE_I8MM() const { return has_arm_sve_i8mm_; }
201201

202202
bool HasArmNeon_BF16() const { return has_arm_neon_bf16_; }
203+
203204
private:
204205
MLASCPUIDInfo();
205206

onnxruntime/core/mlas/lib/sgemm.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1851,4 +1851,4 @@ Return Value:
18511851

18521852
PackedB = (float*)PackedB + AlignedN * CountK;
18531853
}
1854-
}
1854+
}

onnxruntime/core/mlas/lib/sve/mlasi_sve.h

Lines changed: 60 additions & 140 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,10 @@ Module Name:
1313

1414
#pragma once
1515

16-
#include "../mlasi.h"
1716
#include <arm_sve.h> // SVE intrinsic header
1817

18+
#include "../mlasi.h"
19+
1920
#ifndef __clang__
2021
#pragma GCC push_options
2122
#pragma GCC target("arch=armv8.2-a+sve")
@@ -34,132 +35,55 @@ typedef svuint32_t MLAS_SVUINT32;
3435
typedef svbool_t MLAS_SVBOOL;
3536

3637
// function declarations
37-
MLAS_FORCEINLINE
38-
MLAS_SVFLOAT32
39-
MlasSveComputeExpVector(
40-
MLAS_SVBOOL Pred,
41-
MLAS_SVFLOAT32 Vector
42-
);
38+
MLAS_FORCEINLINE MLAS_SVFLOAT32
39+
MlasSveComputeExpVector(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector);
4340

44-
void
45-
MLASCALL
46-
MlasSveComputeExpF32Kernel(
47-
const float* Input,
48-
float* Output,
49-
size_t N
50-
);
41+
void MLASCALL
42+
MlasSveComputeExpF32Kernel(const float* Input, float* Output, size_t N);
5143

5244
MLAS_FORCEINLINE
5345
MLAS_SVFLOAT32
54-
MlasSveComputeSumExpVector(
55-
MLAS_SVBOOL Pred,
56-
MLAS_SVFLOAT32 Vector,
57-
MLAS_SVFLOAT32 NegativeMaximumVector
58-
);
46+
MlasSveComputeSumExpVector(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector, MLAS_SVFLOAT32 NegativeMaximumVector);
5947

60-
float
61-
MLASCALL
62-
MlasSveComputeSumExpF32Kernel(
63-
const float* Input,
64-
float* Output,
65-
size_t N,
66-
const float* NegativeMaximum
67-
);
48+
float MLASCALL
49+
MlasSveComputeSumExpF32Kernel(const float* Input, float* Output, size_t N, const float* NegativeMaximum);
6850

6951
float MLASCALL
70-
MlasSveReduceMaximumF32Kernel(
71-
const float* Input,
72-
size_t N
73-
);
52+
MlasSveReduceMaximumF32Kernel(const float* Input, size_t N);
7453

75-
void
76-
MLASCALL
77-
MlasSveReduceMinimumMaximumF32Kernel(
78-
const float* Input,
79-
float* Min,
80-
float* Max,
81-
size_t N
82-
);
54+
void MLASCALL
55+
MlasSveReduceMinimumMaximumF32Kernel(const float* Input, float* Min, float* Max, size_t N);
8356

84-
void
85-
MLASCALL
86-
MlasSveComputeSoftmaxOutputF32Kernel(
87-
float* Output,
88-
size_t N,
89-
const float* Parameters
90-
);
57+
void MLASCALL
58+
MlasSveComputeSoftmaxOutputF32Kernel(float* Output, size_t N, const float* Parameters);
9159

92-
void
93-
MLASCALL
94-
MlasSveComputeLogSoftmaxOutputF32Kernel(
95-
const float* Input,
96-
float* Output,
97-
size_t N,
98-
const float* Parameters
99-
);
60+
void MLASCALL
61+
MlasSveComputeLogSoftmaxOutputF32Kernel(const float* Input, float* Output, size_t N, const float* Parameters);
10062

101-
void
102-
MLASCALL
103-
MlasSveErfKernel(
104-
const float* Input,
105-
float* Output,
106-
size_t N
107-
);
63+
void MLASCALL
64+
MlasSveErfKernel(const float* Input, float* Output, size_t N);
10865

109-
void
110-
MLASCALL
111-
MlasSveLogisticKernel(
112-
const float* Input,
113-
float* Output,
114-
size_t N
115-
);
66+
void MLASCALL
67+
MlasSveLogisticKernel(const float* Input, float* Output, size_t N);
11668

117-
//MLAS API for SVE intrinsics
69+
// MLAS API for SVE intrinsics
11870
size_t MLASCALL
119-
MlasSgemmKernelAdd_sve(
120-
const float* A,
121-
const float* B,
122-
float* C,
123-
size_t CountK,
124-
size_t CountM,
125-
size_t CountN,
126-
size_t lda,
127-
size_t ldc,
128-
float alpha
129-
);
71+
MlasSgemmKernelAdd_sve(const float* A, const float* B, float* C, size_t CountK, size_t CountM, size_t CountN, size_t lda, size_t ldc, float alpha);
13072

13173
size_t MLASCALL
132-
MlasSgemmKernelZero_sve(
133-
const float* A,
134-
const float* B,
135-
float* C,
136-
size_t CountK,
137-
size_t CountM,
138-
size_t CountN,
139-
size_t lda,
140-
size_t ldc,
141-
float alpha
142-
);
143-
144-
void MLAS_SVE_TARGET
145-
MLASCALL
146-
SVE_ZERO_INITIALIZE(float* d);
147-
148-
void MLAS_SVE_TARGET
149-
MLASCALL
150-
SVE_LOAD_STORE(float* D, const float* b);
74+
MlasSgemmKernelZero_sve(const float* A, const float* B, float* C, size_t CountK, size_t CountM, size_t CountN, size_t lda, size_t ldc, float alpha);
75+
76+
void MLAS_SVE_TARGET MLASCALL
77+
SVE_ZERO_INITIALIZE(float* d);
78+
79+
void MLAS_SVE_TARGET MLASCALL
80+
SVE_LOAD_STORE(float* D, const float* b);
15181

15282
void MLAS_SVE_TARGET MLASCALL
15383
SCATTER_STORE(float* d, const float* b);
15484

155-
void MLAS_SVE_TARGET
156-
MLASCALL
157-
SVE_TRANSPOSE(
158-
float*& D,
159-
const float*& b,
160-
size_t ldb,
161-
size_t& x
162-
);
85+
void MLAS_SVE_TARGET MLASCALL
86+
SVE_TRANSPOSE(float*& D, const float*& b, size_t ldb, size_t& x);
16387

16488
MLAS_SVE_TARGET
16589
inline int
@@ -240,8 +164,8 @@ MLAS_SVE_TARGET
240164
MLAS_FORCEINLINE
241165
MLAS_SVINT32
242166
MlasSveAddInt32(MLAS_SVBOOL Pred, MLAS_SVINT32 Vector1, MLAS_SVINT32 Vector2)
243-
{
244-
return svadd_s32_m(Pred, Vector1, Vector2);
167+
{
168+
return svadd_s32_m(Pred, Vector1, Vector2);
245169
}
246170

247171
MLAS_SVE_TARGET
@@ -298,26 +222,26 @@ MLAS_SVINT32
298222
MlasSveBlendInt32(MLAS_SVBOOL Pred, MLAS_SVINT32 Vector1, MLAS_SVINT32 Vector2, MLAS_SVINT32 Selection)
299223
{
300224
return MlasSveOrInt32(
301-
Pred,
302-
MlasSveAndInt32(Pred, Vector2, Selection),
225+
Pred,
226+
MlasSveAndInt32(Pred, Vector2, Selection),
303227
MlasSveAndNotInt32(Pred, Selection, Vector1)
304228
);
305229
}
306230

307-
template<unsigned ShiftCount>
231+
template <unsigned ShiftCount>
308232
MLAS_SVE_TARGET
309-
MLAS_FORCEINLINE
310-
MLAS_SVUINT32
311-
MlasSveShiftLeftUInt32(MLAS_SVBOOL Pred, MLAS_SVUINT32 Vector)
233+
MLAS_FORCEINLINE
234+
MLAS_SVUINT32
235+
MlasSveShiftLeftUInt32(MLAS_SVBOOL Pred, MLAS_SVUINT32 Vector)
312236
{
313237
return svlsl_n_u32_z(Pred, Vector, ShiftCount);
314238
}
315239

316-
template<unsigned ShiftCount>
240+
template <unsigned ShiftCount>
317241
MLAS_SVE_TARGET
318-
MLAS_FORCEINLINE
319-
MLAS_SVINT32
320-
MlasSveShiftLeftInt32(MLAS_SVBOOL Pred, MLAS_SVINT32 Vector)
242+
MLAS_FORCEINLINE
243+
MLAS_SVINT32
244+
MlasSveShiftLeftInt32(MLAS_SVBOOL Pred, MLAS_SVINT32 Vector)
321245
{
322246
return svlsl_n_s32_z(Pred, Vector, ShiftCount);
323247
}
@@ -402,11 +326,10 @@ MlasSveStoreFloat32(MLAS_SVBOOL Pred, float* Buffer, MLAS_SVFLOAT32 Vector)
402326
svst1_f32(Pred, Buffer, Vector);
403327
}
404328

405-
template<unsigned Lane>
329+
template <unsigned Lane>
406330
MLAS_SVE_TARGET
407-
MLAS_FORCEINLINE
408-
void
409-
MlasSveStoreLaneFloat32(float* Buffer, MLAS_SVFLOAT32 Vector)
331+
MLAS_FORCEINLINE void
332+
MlasSveStoreLaneFloat32(float* Buffer, MLAS_SVFLOAT32 Vector)
410333
{
411334
svbool_t Pred = svwhilelt_b32(Lane, Lane + 1);
412335
svst1_f32(Pred, Buffer, Vector);
@@ -421,11 +344,10 @@ MlasSveStoreLowHalfFloat32(float* Buffer, MLAS_SVFLOAT32 Vector)
421344
svst1_f32(Pred, Buffer, Vector);
422345
}
423346

424-
template<unsigned Lane>
347+
template <unsigned Lane>
425348
MLAS_SVE_TARGET
426-
MLAS_FORCEINLINE
427-
float
428-
MlasSveExtractLaneFloat32(MLAS_SVFLOAT32 Vector)
349+
MLAS_FORCEINLINE float
350+
MlasSveExtractLaneFloat32(MLAS_SVFLOAT32 Vector)
429351
{
430352
float TmpBuffer[1];
431353
svbool_t Pred = svwhilelt_b32(Lane, Lane + 1);
@@ -470,7 +392,7 @@ MLAS_FORCEINLINE
470392
MLAS_SVFLOAT32
471393
MlasSveMultiplyFloat32(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector1, MLAS_SVFLOAT32 Vector2)
472394
{
473-
return svmul_f32_m(Pred, Vector1, Vector2);
395+
return svmul_f32_m(Pred, Vector1, Vector2);
474396
}
475397

476398
MLAS_SVE_TARGET
@@ -484,7 +406,7 @@ MlasSveExpFloat32(MLAS_SVUINT32 Vector)
484406
MLAS_SVE_TARGET
485407
MLAS_FORCEINLINE
486408
MLAS_SVFLOAT32
487-
MlasSveScaleFloat32(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector1, MLAS_SVINT32 Vector2)
409+
MlasSveScaleFloat32(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector1, MLAS_SVINT32 Vector2)
488410
{
489411
return svscale_f32_m(Pred, Vector1, Vector2);
490412
}
@@ -494,7 +416,7 @@ MLAS_FORCEINLINE
494416
MLAS_SVFLOAT32
495417
MlasSveRoundINTFloat32(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector)
496418
{
497-
return svrintm_f32_z(Pred, Vector);
419+
return svrintm_f32_z(Pred, Vector);
498420
}
499421

500422
MLAS_SVE_TARGET
@@ -537,10 +459,10 @@ MlasSveGreaterThanFloat32(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector1, MLAS_SVFLOAT
537459
// Compare Vector1 and Vector2, return a predicate vector
538460
svbool_t cmp_mask = svcmpgt_f32(Pred, Vector1, Vector2);
539461

540-
//Convert predicate to uint32_t mask
462+
// Convert predicate to uint32_t mask
541463
svuint32_t mask_bits = svdup_u32_z(cmp_mask, 0xFFFFFFFF);
542464

543-
//Reinterpret to float32
465+
// Reinterpret to float32
544466
return svreinterpret_f32_u32(mask_bits);
545467
}
546468

@@ -551,7 +473,7 @@ MlasSveAndFloat32(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector1, MLAS_SVFLOAT32 Vecto
551473
{
552474
return MlasSveReinterpretAsFloat32(
553475
MlasSveAndInt32(
554-
Pred,
476+
Pred,
555477
MlasSveReinterpretAsInt32(Vector1),
556478
MlasSveReinterpretAsInt32(Vector2)
557479
)
@@ -606,7 +528,7 @@ MLAS_SVFLOAT32
606528
MlasSveBlendFloat32(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector1, MLAS_SVFLOAT32 Vector2, MLAS_SVFLOAT32 Selection)
607529
{
608530
return MlasSveOrFloat32(
609-
Pred,
531+
Pred,
610532
MlasSveAndFloat32(Pred, Vector2, Selection),
611533
MlasSveAndFloat32(Pred, Vector1, Selection)
612534
);
@@ -668,8 +590,8 @@ MLAS_SVFLOAT32
668590
MlasSvePowerOf2Float32(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector)
669591
{
670592
MLAS_SVINT32 emm0 = MlasSveAddInt32(
671-
Pred,
672-
MlasSveCastToInt32(Pred, Vector),
593+
Pred,
594+
MlasSveCastToInt32(Pred, Vector),
673595
MlasSveBroadcastInt32(127)
674596
);
675597
return MlasSveReinterpretAsFloat32(MlasSveShiftLeftInt32<23>(Pred, emm0));
@@ -700,8 +622,7 @@ MlasSveCompareLessThan(svbool_t Pred, MLAS_SVFLOAT32 A, MLAS_SVFLOAT32 B)
700622
}
701623

702624
MLASCALL
703-
inline
704-
void
625+
inline void
705626
Transpose_SVE512_4x4(float* D, const float* B, size_t ldb)
706627
{
707628
const static int VL = svcntw();
@@ -736,8 +657,7 @@ MlasSveCompareGreaterThan(svbool_t Pred, MLAS_SVFLOAT32 A, MLAS_SVFLOAT32 B)
736657
}
737658

738659
MLASCALL
739-
inline
740-
void
660+
inline void
741661
Transpose_SVE256_4x4(float* D, const float* B, size_t ldb)
742662
{
743663
const static int VL = svcntw();

0 commit comments

Comments
 (0)