@@ -67,33 +67,41 @@ YNN_DOT_KERNEL(arch_flag::amxint8, dot_uint8_int8_int32_16x64x64_16x16x4_amxint8
6767
6868#ifndef YNN_DISABLE_SME
6969#ifdef YNN_ARCH_ARM64_SME2
70- YNN_DOT_KERNEL (arch_flag::sme2, dot_fp32_sme2, sme_vl(float {}),
71- 4 * sme_vl(float {}), 1, sme_vl(float {}), 1,
72- dot_flag::transpose_a, float, float, float)
73- YNN_DOT_KERNEL(arch_flag::sme2, dot_bf16_bf16_fp32_sme2, sme_vl(float {}),
74- 4 * sme_vl(float {}), 2, sme_vl(float {}), 2,
75- dot_flag::transpose_a, bfloat16, bfloat16, float)
76- YNN_DOT_KERNEL(arch_flag::sme2, dot_fp16_fp16_fp32_sme2, sme_vl(float {}),
77- 4 * sme_vl(float {}), 2, sme_vl(float {}), 2,
78- dot_flag::transpose_a, half, half, float)
79- YNN_DOT_KERNEL(arch_flag::sme2, dot_int8_int8_int32_sme2, sme_vl(int32_t {}),
80- 4 * sme_vl(int32_t {}), 4, sme_vl(int32_t {}), 4,
81- dot_flag::transpose_a, int8_t, int8_t, int32_t)
70+ YNN_DOT_KERNEL (arch_flag::sme2, dot_fp32_sme2,
71+ sme_vl (float {}), 4 * sme_vl(float {}), 1, sme_vl(float {}), 1,
72+ dot_flag::transpose_a | dot_flag::consistent_arithmetic,
73+ float, float, float)
74+ YNN_DOT_KERNEL(arch_flag::sme2, dot_bf16_bf16_fp32_sme2,
75+ sme_vl (float {}), 4 * sme_vl(float {}), 2, sme_vl(float {}), 2,
76+ dot_flag::transpose_a,
77+ bfloat16, bfloat16, float)
78+ YNN_DOT_KERNEL(arch_flag::sme2, dot_fp16_fp16_fp32_sme2,
79+ sme_vl (float {}), 4 * sme_vl(float {}), 2, sme_vl(float {}), 2,
80+ dot_flag::transpose_a,
81+ half, half, float)
82+ YNN_DOT_KERNEL(arch_flag::sme2, dot_int8_int8_int32_sme2,
83+ sme_vl (int32_t {}), 4 * sme_vl(int32_t {}), 4, sme_vl(int32_t {}), 4,
84+ dot_flag::transpose_a | dot_flag::consistent_arithmetic,
85+ int8_t, int8_t, int32_t)
8286#endif // YNN_ARCH_ARM64_SME2
8387
8488#ifdef YNN_ARCH_ARM64_SME
85- YNN_DOT_KERNEL (arch_flag::sme, dot_fp32_sme, sme_vl(float {}),
86- 4 * sme_vl(float {}), 1, sme_vl(float {}), 1,
87- dot_flag::transpose_a, float, float, float)
88- YNN_DOT_KERNEL(arch_flag::sme, dot_bf16_bf16_fp32_sme, sme_vl(float {}),
89- 4 * sme_vl(float {}), 2, sme_vl(float {}), 2,
90- dot_flag::transpose_a, bfloat16, bfloat16, float)
91- YNN_DOT_KERNEL(arch_flag::sme, dot_fp16_fp16_fp32_sme, sme_vl(float {}),
92- 4 * sme_vl(float {}), 2, sme_vl(float {}), 2,
93- dot_flag::transpose_a, half, half, float)
94- YNN_DOT_KERNEL(arch_flag::sme, dot_int8_int8_int32_sme, sme_vl(int32_t {}),
95- 4 * sme_vl(int32_t {}), 4, sme_vl(int32_t {}), 4,
96- dot_flag::transpose_a, int8_t, int8_t, int32_t)
89+ YNN_DOT_KERNEL (arch_flag::sme, dot_fp32_sme,
90+ sme_vl (float {}), 4 * sme_vl(float {}), 1, sme_vl(float {}), 1,
91+ dot_flag::transpose_a | dot_flag::consistent_arithmetic,
92+ float, float, float)
93+ YNN_DOT_KERNEL(arch_flag::sme, dot_bf16_bf16_fp32_sme,
94+ sme_vl (float {}), 4 * sme_vl(float {}), 2, sme_vl(float {}), 2,
95+ dot_flag::transpose_a,
96+ bfloat16, bfloat16, float)
97+ YNN_DOT_KERNEL(arch_flag::sme, dot_fp16_fp16_fp32_sme,
98+ sme_vl (float {}), 4 * sme_vl(float {}), 2, sme_vl(float {}), 2,
99+ dot_flag::transpose_a,
100+ half, half, float)
101+ YNN_DOT_KERNEL(arch_flag::sme, dot_int8_int8_int32_sme,
102+ sme_vl (int32_t {}), 4 * sme_vl(int32_t {}), 4, sme_vl(int32_t {}), 4,
103+ dot_flag::transpose_a | dot_flag::consistent_arithmetic,
104+ int8_t, int8_t, int32_t)
97105#endif // YNN_ARCH_ARM64_SME
98106#endif // YNN_DISABLE_SME
99107
0 commit comments