Skip to content

Commit 43f9c2f

Browse files
dsharletg authored and xnnpack-bot committed
Add test coverage of YNN_FLAG_CONSISTENT_ARITHMETIC
This doesn't test for consistency itself (which would be difficult to do), but we should test the correctness of the dot subgraph code with this flag set. Currently, there are few, if any, branches in the code based on this flag, but that will be changing soon. Adding this coverage uncovered a bug: we didn't mark any Arm kernels as consistent (i.e. none of them set `dot_flag::consistent_arithmetic`).

PiperOrigin-RevId: 835645726
1 parent 81f87bd commit 43f9c2f

File tree

5 files changed

+50
-27
lines changed

5 files changed

+50
-27
lines changed

ynnpack/kernels/dot/generator/arm_bf16_bf16_fp32.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ def __init__(self):
1111
super().__init__("neon", "bf16_bf16_fp32", "float", (1, 4, 1))
1212
self.a_type = "bfloat16"
1313
self.b_type = "bfloat16"
14+
self.flags += ["dot_flag::consistent_arithmetic"]
1415

1516
def header(self):
1617
return super().header() + """

ynnpack/kernels/dot/generator/arm_fp32.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ def __init__(self):
88
super().__init__("neon", "fp32", "float", (1, 4, 1))
99
self.a_type = "float"
1010
self.b_type = "float"
11+
self.flags += ["dot_flag::consistent_arithmetic"]
1112

1213
def header(self):
1314
return super().header() + """

ynnpack/kernels/dot/generator/arm_int8_int8_int32.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ def __init__(self, arch, tile_shape):
88
super().__init__(arch, "int8_int8_int32", "int32_t", tile_shape)
99
self.a_type = "int8_t"
1010
self.b_type = "int8_t"
11+
self.flags += ["dot_flag::consistent_arithmetic"]
1112

1213

1314
class arm_neon_int8_int8_int32(arm_int8_int8_int32):

ynnpack/kernels/dot/kernels.inc

Lines changed: 32 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -67,33 +67,41 @@ YNN_DOT_KERNEL(arch_flag::amxint8, dot_uint8_int8_int32_16x64x64_16x16x4_amxint8
6767

6868
#ifndef YNN_DISABLE_SME
6969
#ifdef YNN_ARCH_ARM64_SME2
70-
YNN_DOT_KERNEL(arch_flag::sme2, dot_fp32_sme2, sme_vl(float{}),
71-
4 * sme_vl(float{}), 1, sme_vl(float{}), 1,
72-
dot_flag::transpose_a, float, float, float)
73-
YNN_DOT_KERNEL(arch_flag::sme2, dot_bf16_bf16_fp32_sme2, sme_vl(float{}),
74-
4 * sme_vl(float{}), 2, sme_vl(float{}), 2,
75-
dot_flag::transpose_a, bfloat16, bfloat16, float)
76-
YNN_DOT_KERNEL(arch_flag::sme2, dot_fp16_fp16_fp32_sme2, sme_vl(float{}),
77-
4 * sme_vl(float{}), 2, sme_vl(float{}), 2,
78-
dot_flag::transpose_a, half, half, float)
79-
YNN_DOT_KERNEL(arch_flag::sme2, dot_int8_int8_int32_sme2, sme_vl(int32_t{}),
80-
4 * sme_vl(int32_t{}), 4, sme_vl(int32_t{}), 4,
81-
dot_flag::transpose_a, int8_t, int8_t, int32_t)
70+
YNN_DOT_KERNEL(arch_flag::sme2, dot_fp32_sme2,
71+
sme_vl(float{}), 4 * sme_vl(float{}), 1, sme_vl(float{}), 1,
72+
dot_flag::transpose_a | dot_flag::consistent_arithmetic,
73+
float, float, float)
74+
YNN_DOT_KERNEL(arch_flag::sme2, dot_bf16_bf16_fp32_sme2,
75+
sme_vl(float{}), 4 * sme_vl(float{}), 2, sme_vl(float{}), 2,
76+
dot_flag::transpose_a,
77+
bfloat16, bfloat16, float)
78+
YNN_DOT_KERNEL(arch_flag::sme2, dot_fp16_fp16_fp32_sme2,
79+
sme_vl(float{}), 4 * sme_vl(float{}), 2, sme_vl(float{}), 2,
80+
dot_flag::transpose_a,
81+
half, half, float)
82+
YNN_DOT_KERNEL(arch_flag::sme2, dot_int8_int8_int32_sme2,
83+
sme_vl(int32_t{}), 4 * sme_vl(int32_t{}), 4, sme_vl(int32_t{}), 4,
84+
dot_flag::transpose_a | dot_flag::consistent_arithmetic,
85+
int8_t, int8_t, int32_t)
8286
#endif // YNN_ARCH_ARM64_SME2
8387

8488
#ifdef YNN_ARCH_ARM64_SME
85-
YNN_DOT_KERNEL(arch_flag::sme, dot_fp32_sme, sme_vl(float{}),
86-
4 * sme_vl(float{}), 1, sme_vl(float{}), 1,
87-
dot_flag::transpose_a, float, float, float)
88-
YNN_DOT_KERNEL(arch_flag::sme, dot_bf16_bf16_fp32_sme, sme_vl(float{}),
89-
4 * sme_vl(float{}), 2, sme_vl(float{}), 2,
90-
dot_flag::transpose_a, bfloat16, bfloat16, float)
91-
YNN_DOT_KERNEL(arch_flag::sme, dot_fp16_fp16_fp32_sme, sme_vl(float{}),
92-
4 * sme_vl(float{}), 2, sme_vl(float{}), 2,
93-
dot_flag::transpose_a, half, half, float)
94-
YNN_DOT_KERNEL(arch_flag::sme, dot_int8_int8_int32_sme, sme_vl(int32_t{}),
95-
4 * sme_vl(int32_t{}), 4, sme_vl(int32_t{}), 4,
96-
dot_flag::transpose_a, int8_t, int8_t, int32_t)
89+
YNN_DOT_KERNEL(arch_flag::sme, dot_fp32_sme,
90+
sme_vl(float{}), 4 * sme_vl(float{}), 1, sme_vl(float{}), 1,
91+
dot_flag::transpose_a | dot_flag::consistent_arithmetic,
92+
float, float, float)
93+
YNN_DOT_KERNEL(arch_flag::sme, dot_bf16_bf16_fp32_sme,
94+
sme_vl(float{}), 4 * sme_vl(float{}), 2, sme_vl(float{}), 2,
95+
dot_flag::transpose_a,
96+
bfloat16, bfloat16, float)
97+
YNN_DOT_KERNEL(arch_flag::sme, dot_fp16_fp16_fp32_sme,
98+
sme_vl(float{}), 4 * sme_vl(float{}), 2, sme_vl(float{}), 2,
99+
dot_flag::transpose_a,
100+
half, half, float)
101+
YNN_DOT_KERNEL(arch_flag::sme, dot_int8_int8_int32_sme,
102+
sme_vl(int32_t{}), 4 * sme_vl(int32_t{}), 4, sme_vl(int32_t{}), 4,
103+
dot_flag::transpose_a | dot_flag::consistent_arithmetic,
104+
int8_t, int8_t, int32_t)
97105
#endif // YNN_ARCH_ARM64_SME
98106
#endif // YNN_DISABLE_SME
99107

ynnpack/subgraph/test/dot.cc

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,11 @@ void TestStaticB(A, B, C) {
157157
Tensor<B> b(to_physical_shape<B>(b_shape));
158158
b.generate([&]() { return b_gen(rng); });
159159

160-
SubgraphBuilder subgraph(4);
160+
uint32_t subgraph_flags = 0;
161+
if (random_bool(rng)) {
162+
subgraph_flags |= YNN_FLAG_CONSISTENT_ARITHMETIC;
163+
}
164+
SubgraphBuilder subgraph(4, subgraph_flags);
161165
const uint32_t a_id = 0;
162166
const uint32_t b_id = 1;
163167
const uint32_t output_id = 3;
@@ -331,7 +335,11 @@ void TestDynamicB(A, B, C) {
331335
std::uniform_int_distribution<size_t>(1, max_k_dims)(rng);
332336
const size_t output_rank = std::max(a_rank, b_rank) - num_k_dims + 1;
333337

334-
SubgraphBuilder subgraph(4);
338+
uint32_t subgraph_flags = 0;
339+
if (random_bool(rng)) {
340+
subgraph_flags |= YNN_FLAG_CONSISTENT_ARITHMETIC;
341+
}
342+
SubgraphBuilder subgraph(4, subgraph_flags);
335343
const uint32_t a_id = 0;
336344
const uint32_t b_id = 1;
337345
const uint32_t output_id = 3;
@@ -512,7 +520,11 @@ void TestStaticShapeDynamicB(A, B, C) {
512520
std::sort(inv_b_perm.begin(), inv_b_perm.end(),
513521
[&](int i, int j) { return b_perm[i] < b_perm[j]; });
514522

515-
SubgraphBuilder subgraph(4);
523+
uint32_t subgraph_flags = 0;
524+
if (random_bool(rng)) {
525+
subgraph_flags |= YNN_FLAG_CONSISTENT_ARITHMETIC;
526+
}
527+
SubgraphBuilder subgraph(4, subgraph_flags);
516528
const uint32_t a_id = 0;
517529
const uint32_t b_id = 1;
518530
const uint32_t output_id = 3;

0 commit comments

Comments
 (0)