Skip to content

Commit 24d8cf8

Browse files
authored
[opt](arm) Improve count_zero_num performance with NEON intrinsics (#58615)
btw fixed the bug of count zero num for nullable column in x86
1 parent 1732416 commit 24d8cf8

File tree

4 files changed

+297
-11
lines changed

4 files changed

+297
-11
lines changed

be/benchmark/benchmark_bits.hpp

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
#include <benchmark/benchmark.h>
19+
20+
#include "util/simd/bits.h"
21+
22+
namespace doris {} // namespace doris
23+
24+
static void BM_Bits_CountZeroNum(benchmark::State& state) {
25+
const auto n = static_cast<size_t>(state.range(0));
26+
std::vector<int8_t> data(n, 0);
27+
28+
for (auto _ : state) {
29+
auto r = doris::simd::count_zero_num<size_t>(data.data(), data.size());
30+
benchmark::DoNotOptimize(r);
31+
}
32+
33+
state.SetBytesProcessed(state.iterations() * n);
34+
}
35+
36+
static void BM_Bits_CountZeroNumNullMap(benchmark::State& state) {
37+
const auto n = static_cast<size_t>(state.range(0));
38+
std::vector<int8_t> data(n, 0);
39+
std::vector<uint8_t> null_map(n, 0);
40+
41+
for (auto _ : state) {
42+
auto r = doris::simd::count_zero_num<size_t>(data.data(), null_map.data(), data.size());
43+
benchmark::DoNotOptimize(r);
44+
}
45+
46+
state.SetBytesProcessed(state.iterations() * n);
47+
}
48+
49+
BENCHMARK(BM_Bits_CountZeroNum)
50+
->Unit(benchmark::kNanosecond)
51+
->Arg(16) // 16 bytes
52+
->Arg(32) // 32 bytes
53+
->Arg(64) // 64 bytes
54+
->Arg(256) // 256 bytes
55+
->Arg(1024) // 1KB
56+
->Repetitions(5)
57+
->DisplayAggregatesOnly();
58+
59+
BENCHMARK(BM_Bits_CountZeroNumNullMap)
60+
->Unit(benchmark::kNanosecond)
61+
->Arg(16) // 16 bytes
62+
->Arg(32) // 32 bytes
63+
->Arg(64) // 64 bytes
64+
->Arg(256) // 256 bytes
65+
->Arg(1024) // 1KB
66+
->Repetitions(5)
67+
->DisplayAggregatesOnly();

be/benchmark/benchmark_main.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include <benchmark/benchmark.h>
1919

2020
#include "benchmark_bit_pack.hpp"
21+
#include "benchmark_bits.hpp"
2122
#include "benchmark_block_bloom_filter.hpp"
2223
#include "benchmark_fastunion.hpp"
2324
#include "benchmark_hll_merge.hpp"

be/src/util/simd/bits.h

Lines changed: 51 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
#include <type_traits>
2323
#include <vector>
2424

25-
#if defined(__ARM_NEON) && defined(__aarch64__)
25+
#if defined(__ARM_NEON)
2626
#include <arm_neon.h>
2727
#endif
2828

@@ -130,7 +130,21 @@ template <typename T>
130130
inline T count_zero_num(const int8_t* __restrict data, T size) {
131131
T num = 0;
132132
const int8_t* end = data + size;
133-
#if defined(__SSE2__) && defined(__POPCNT__)
133+
#if defined(__ARM_NEON)
134+
const int8_t* end64 = data + (size / 64 * 64);
135+
136+
for (; data < end64; data += 64) {
137+
auto a0 = vshrq_n_u8(vceqzq_s8(vld1q_s8(data)), 7);
138+
auto a1 = vshrq_n_u8(vceqzq_s8(vld1q_s8(data + 16)), 7);
139+
auto a2 = vshrq_n_u8(vceqzq_s8(vld1q_s8(data + 32)), 7);
140+
auto a3 = vshrq_n_u8(vceqzq_s8(vld1q_s8(data + 48)), 7);
141+
142+
auto s0 = vaddq_u8(a0, a1);
143+
auto s1 = vaddq_u8(a2, a3);
144+
auto s = vaddq_u8(s0, s1);
145+
num += vaddvq_u8(s);
146+
}
147+
#elif defined(__SSE2__) && defined(__POPCNT__)
134148
const __m128i zero16 = _mm_setzero_si128();
135149
const int8_t* end64 = data + (size / 64 * 64);
136150

@@ -160,34 +174,60 @@ template <typename T>
160174
inline T count_zero_num(const int8_t* __restrict data, const uint8_t* __restrict null_map, T size) {
161175
T num = 0;
162176
const int8_t* end = data + size;
163-
#if defined(__SSE2__) && defined(__POPCNT__)
177+
#if defined(__ARM_NEON)
178+
const int8_t* end64 = data + (size / 64 * 64);
179+
180+
for (; data < end64; data += 64, null_map += 64) {
181+
auto a0 = vshrq_n_u8(vceqzq_s8(vld1q_s8(data)), 7);
182+
auto a1 = vshrq_n_u8(vceqzq_s8(vld1q_s8(data + 16)), 7);
183+
auto a2 = vshrq_n_u8(vceqzq_s8(vld1q_s8(data + 32)), 7);
184+
auto a3 = vshrq_n_u8(vceqzq_s8(vld1q_s8(data + 48)), 7);
185+
186+
auto r0 = vorrq_u8(a0, vld1q_u8(null_map));
187+
auto r1 = vorrq_u8(a1, vld1q_u8(null_map + 16));
188+
auto r2 = vorrq_u8(a2, vld1q_u8(null_map + 32));
189+
auto r3 = vorrq_u8(a3, vld1q_u8(null_map + 48));
190+
191+
auto s0 = vaddq_u8(r0, r1);
192+
auto s1 = vaddq_u8(r2, r3);
193+
auto s = vaddq_u8(s0, s1);
194+
num += vaddvq_u8(s);
195+
}
196+
#elif defined(__SSE2__) && defined(__POPCNT__)
164197
const __m128i zero16 = _mm_setzero_si128();
198+
const __m128i one16 = _mm_set1_epi8(1);
165199
const int8_t* end64 = data + (size / 64 * 64);
166200

167201
for (; data < end64; data += 64, null_map += 64) {
168202
num += __builtin_popcountll(
169203
static_cast<uint64_t>(_mm_movemask_epi8(_mm_or_si128(
170204
_mm_cmpeq_epi8(_mm_loadu_si128(reinterpret_cast<const __m128i*>(data)),
171205
zero16),
172-
_mm_loadu_si128(reinterpret_cast<const __m128i*>(null_map))))) |
206+
_mm_cmpeq_epi8(_mm_loadu_si128(reinterpret_cast<const __m128i*>(null_map)),
207+
one16)))) |
173208
(static_cast<uint64_t>(_mm_movemask_epi8(_mm_or_si128(
174209
_mm_cmpeq_epi8(
175210
_mm_loadu_si128(reinterpret_cast<const __m128i*>(data + 16)),
176211
zero16),
177-
_mm_loadu_si128(reinterpret_cast<const __m128i*>(null_map + 16)))))
212+
_mm_cmpeq_epi8(
213+
_mm_loadu_si128(reinterpret_cast<const __m128i*>(null_map + 16)),
214+
one16))))
178215
<< 16U) |
179216
(static_cast<uint64_t>(_mm_movemask_epi8(_mm_or_si128(
180217
_mm_cmpeq_epi8(
181218
_mm_loadu_si128(reinterpret_cast<const __m128i*>(data + 32)),
182219
zero16),
183-
_mm_loadu_si128(reinterpret_cast<const __m128i*>(null_map + 32)))))
220+
_mm_cmpeq_epi8(
221+
_mm_loadu_si128(reinterpret_cast<const __m128i*>(null_map + 32)),
222+
one16))))
184223
<< 32U) |
185224
(static_cast<uint64_t>(_mm_movemask_epi8(_mm_or_si128(
186-
_mm_cmpeq_epi8(
187-
_mm_loadu_si128(reinterpret_cast<const __m128i*>(data + 48)),
188-
zero16),
189-
_mm_loadu_si128(reinterpret_cast<const __m128i*>(null_map + 48)))))
190-
<< 48U));
225+
_mm_cmpeq_epi8(_mm_loadu_si128(reinterpret_cast<const __m128i*>(data + 48)),
226+
zero16),
227+
_mm_cmpeq_epi8(
228+
_mm_loadu_si128(reinterpret_cast<const __m128i*>(null_map + 48)),
229+
one16)))))
230+
<< 48U);
191231
}
192232
#endif
193233
for (; data < end; ++data, ++null_map) {

be/test/util/simd/bits_test.cpp

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
#include "util/simd/bits.h"
19+
20+
#include <gtest/gtest-message.h>
21+
#include <gtest/gtest-test-part.h>
22+
#include <gtest/gtest.h>
23+
24+
#include <algorithm>
25+
26+
#include "gtest/gtest_pred_impl.h"
27+
28+
namespace doris::simd {
29+
TEST(BitsTest, BytesMaskToBitsMask) {
30+
// Length determined by architecture (16 on NEON/aarch64, else 32)
31+
constexpr auto len = bits_mask_length();
32+
std::vector<uint8_t> data(len, 0);
33+
// Mark some indices as 1 (non‑zero)
34+
std::vector<size_t> marked = {0, len / 2, len - 1};
35+
for (auto i : marked) {
36+
data[i] = 1;
37+
}
38+
39+
// Build mask
40+
auto mask = bytes_mask_to_bits_mask(data.data());
41+
42+
// Collect indices via iterate_through_bits_mask
43+
std::vector<size_t> collected;
44+
iterate_through_bits_mask([&](size_t idx) { collected.push_back(idx); }, mask);
45+
46+
// Sort to compare (iterate_through_bits_mask already gives ascending, but be safe)
47+
std::sort(collected.begin(), collected.end());
48+
49+
// Expect collected matches 'marked'
50+
EXPECT_EQ(collected.size(), marked.size());
51+
for (size_t i = 0; i < marked.size(); ++i) {
52+
EXPECT_EQ(collected[i], marked[i]);
53+
}
54+
55+
// All zero -> mask == 0
56+
std::vector<uint8_t> zeros(len, 0);
57+
auto zero_mask = bytes_mask_to_bits_mask(zeros.data());
58+
EXPECT_EQ(zero_mask, decltype(zero_mask)(0));
59+
60+
// All ones -> mask == bits_mask_all()
61+
std::vector<uint8_t> ones(len, 1);
62+
auto full_mask = bytes_mask_to_bits_mask(ones.data());
63+
EXPECT_EQ(full_mask, bits_mask_all());
64+
}
65+
66+
TEST(BitsTest, CountZeroNum) {
67+
// Case 1: empty
68+
const int8_t* empty = nullptr;
69+
EXPECT_EQ(count_zero_num<size_t>(empty, size_t(0)), 0U);
70+
EXPECT_EQ(count_zero_num<size_t>(empty, static_cast<const uint8_t*>(nullptr), size_t(0)), 0U);
71+
72+
// Case 2: all zero
73+
{
74+
std::vector<int8_t> v(10, 0);
75+
std::vector<uint8_t> null_map(10, 0);
76+
EXPECT_EQ(count_zero_num<size_t>(v.data(), v.size()), 10U);
77+
EXPECT_EQ(count_zero_num<size_t>(v.data(), null_map.data(), v.size()), 10U);
78+
}
79+
80+
// Case 3: no zero, some nulls
81+
{
82+
std::vector<int8_t> v = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
83+
std::vector<uint8_t> null_map = {0, 1, 0, 0, 1, 0, 0, 1, 0, 0};
84+
EXPECT_EQ(count_zero_num<size_t>(v.data(), v.size()), 0U);
85+
EXPECT_EQ(count_zero_num<size_t>(v.data(), null_map.data(), v.size()), 3U);
86+
}
87+
88+
// Case 4: mixed zeros and nulls union
89+
{
90+
// zeros at 0,2,5 ; nulls at 1,4,6
91+
std::vector<int8_t> v = {0, 1, 0, 1, 1, 0, 1, 1};
92+
std::vector<uint8_t> null_map = {0, 1, 0, 0, 1, 0, 1, 0};
93+
EXPECT_EQ(count_zero_num<size_t>(v.data(), v.size()), 3U);
94+
EXPECT_EQ(count_zero_num<size_t>(v.data(), null_map.data(), v.size()), 6U);
95+
}
96+
97+
// Case 5: large (>64) to exercise SIMD path
98+
{
99+
std::vector<int8_t> v(128);
100+
std::vector<uint8_t> null_map(128);
101+
size_t expect_zero = 0;
102+
size_t expect_union = 0;
103+
for (size_t i = 0; i < v.size(); ++i) {
104+
v[i] = (i % 5 == 0) ? 0 : 1;
105+
null_map[i] = (i % 7 == 0) ? 1 : 0;
106+
if (v[i] == 0) {
107+
++expect_zero;
108+
}
109+
expect_union += static_cast<uint8_t>(!v[i]) | null_map[i];
110+
}
111+
EXPECT_EQ(count_zero_num<size_t>(v.data(), v.size()), expect_zero);
112+
EXPECT_EQ(count_zero_num<size_t>(v.data(), null_map.data(), v.size()), expect_union);
113+
}
114+
115+
// Case 6: tail check (size not multiple of 16/64)
116+
{
117+
size_t n = 128 + 13;
118+
std::vector<int8_t> v(n);
119+
std::vector<uint8_t> null_map(n);
120+
size_t expect_zero = 0;
121+
size_t expect_union = 0;
122+
for (size_t i = 0; i < n; ++i) {
123+
v[i] = (i % 5 == 0) ? 0 : 1;
124+
null_map[i] = (i % 7 == 0) ? 1 : 0;
125+
if (v[i] == 0) {
126+
++expect_zero;
127+
}
128+
expect_union += static_cast<uint8_t>(!v[i]) | null_map[i];
129+
}
130+
EXPECT_EQ(count_zero_num<size_t>(v.data(), n), expect_zero);
131+
EXPECT_EQ(count_zero_num<size_t>(v.data(), null_map.data(), n), expect_union);
132+
}
133+
}
134+
135+
TEST(BitsTest, FindByte) {
136+
std::vector<uint8_t> v = {5, 0, 1, 7, 1, 9, 0, 3};
137+
EXPECT_EQ(find_byte<uint8_t>(v, 0, uint8_t(5)), 0U);
138+
EXPECT_EQ(find_byte<uint8_t>(v, 0, uint8_t(0)), 1U);
139+
EXPECT_EQ(find_byte<uint8_t>(v, 2, uint8_t(1)), 2U);
140+
EXPECT_EQ(find_byte<uint8_t>(v, 3, uint8_t(1)), 4U);
141+
EXPECT_EQ(find_byte<uint8_t>(v, 0, uint8_t(42)), v.size());
142+
EXPECT_EQ(find_byte<uint8_t>(v, v.size(), uint8_t(5)), v.size());
143+
const uint8_t* data = v.data();
144+
EXPECT_EQ(find_byte<uint8_t>(data, 0, 5, uint8_t(0)), 1U);
145+
EXPECT_EQ(find_byte<uint8_t>(data, 2, 6, uint8_t(1)), 2U);
146+
EXPECT_EQ(find_byte<uint8_t>(data, 3, 6, uint8_t(0)), 6U);
147+
EXPECT_EQ(find_byte<uint8_t>(data, 6, 6, uint8_t(3)), 6U);
148+
}
149+
150+
TEST(BitsTest, ContainByte) {
151+
std::vector<uint8_t> v = {5, 0, 1, 7, 1, 9, 0, 3};
152+
const uint8_t* data = v.data();
153+
EXPECT_TRUE(contain_byte<uint8_t>(data, v.size(), static_cast<signed char>(5)));
154+
EXPECT_TRUE(contain_byte<uint8_t>(data, v.size(), static_cast<signed char>(0)));
155+
EXPECT_TRUE(contain_byte<uint8_t>(data, v.size(), static_cast<signed char>(1)));
156+
EXPECT_TRUE(contain_byte<uint8_t>(data, v.size(), static_cast<signed char>(3)));
157+
EXPECT_FALSE(contain_byte<uint8_t>(data, v.size(), static_cast<signed char>(42)));
158+
EXPECT_FALSE(contain_byte<uint8_t>(data, 0, static_cast<signed char>(5)));
159+
}
160+
161+
TEST(BitsTest, FindOne) {
162+
std::vector<uint8_t> v = {5, 0, 1, 7, 1, 9, 0, 3};
163+
const uint8_t* data = v.data();
164+
EXPECT_EQ(find_one(v, 0), 2U);
165+
EXPECT_EQ(find_one(v, 3), 4U);
166+
EXPECT_EQ(find_one(v, 5), v.size());
167+
EXPECT_EQ(find_one(data, 0, v.size()), 2U);
168+
EXPECT_EQ(find_one(data, 4, v.size()), 4U);
169+
EXPECT_EQ(find_one(data, 5, v.size()), v.size());
170+
}
171+
172+
TEST(BitsTest, FindZero) {
173+
std::vector<uint8_t> v = {5, 0, 1, 7, 1, 9, 0, 3};
174+
EXPECT_EQ(find_zero(v, 0), 1U);
175+
EXPECT_EQ(find_zero(v, 2), 6U);
176+
EXPECT_EQ(find_zero(v, 7), v.size());
177+
}
178+
} //namespace doris::simd

0 commit comments

Comments
 (0)