96 changes: 10 additions & 86 deletions src/CodeGen_ARM.cpp
@@ -1475,17 +1475,13 @@ void CodeGen_ARM::visit(const Store *op) {
is_float16_and_has_feature(elt) ||
elt == Int(8) || elt == Int(16) || elt == Int(32) || elt == Int(64) ||
elt == UInt(8) || elt == UInt(16) || elt == UInt(32) || elt == UInt(64)) {
// TODO(zvookin): Handle vector_bits_*.
const int target_vector_bits = native_vector_bits();
if (vec_bits % 128 == 0) {
type_ok_for_vst = true;
int target_vector_bits = native_vector_bits();
if (target_vector_bits == 0) {
target_vector_bits = 128;
}
intrin_type = intrin_type.with_lanes(target_vector_bits / t.bits());
} else if (vec_bits % 64 == 0) {
type_ok_for_vst = true;
auto intrin_bits = (vec_bits % 128 == 0 || target.has_feature(Target::SVE2)) ? 128 : 64;
auto intrin_bits = (vec_bits % 128 == 0 || target.has_feature(Target::SVE2)) ? target_vector_bits : 64;
intrin_type = intrin_type.with_lanes(intrin_bits / t.bits());
}
}
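Note on the hunk above: in the 64-bit-multiple branch, the intrinsic type is now sized from the target's native vector width under SVE2 instead of a hard-coded 128 bits. A minimal sketch of the resulting lane arithmetic, assuming a hypothetical 256-bit SVE2 target; intrin_lanes below is an illustrative helper, not Halide code:

#include <cassert>

// Illustrative only: lanes covered by one vstN intrinsic call for a given
// element width. The chosen intrinsic width is the native vector width on
// SVE2, and 128 bits (a q register) or 64 bits (a d register) on plain NEON.
static int intrin_lanes(int intrin_bits, int element_bits) {
    return intrin_bits / element_bits;
}

int main() {
    assert(intrin_lanes(256, 32) == 8);  // hypothetical 256-bit SVE2 target, 32-bit elements
    assert(intrin_lanes(128, 32) == 4);  // NEON q-register path
    assert(intrin_lanes(64, 16) == 4);   // NEON d-register path
    return 0;
}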
@@ -1494,7 +1490,9 @@ void CodeGen_ARM::visit(const Store *op) {
if (ramp && is_const_one(ramp->stride) &&
shuffle && shuffle->is_interleave() &&
type_ok_for_vst &&
2 <= shuffle->vectors.size() && shuffle->vectors.size() <= 4) {
2 <= shuffle->vectors.size() && shuffle->vectors.size() <= 4 &&
// TODO: we could handle predicated_store once shuffle_vector becomes robust for scalable vectors
!is_predicated_store) {

const int num_vecs = shuffle->vectors.size();
vector<Value *> args(num_vecs);
@@ -1513,7 +1511,6 @@ void CodeGen_ARM::visit(const Store *op) {
for (int i = 0; i < num_vecs; ++i) {
args[i] = codegen(shuffle->vectors[i]);
}
Value *store_pred_val = codegen(op->predicate);

bool is_sve = target.has_feature(Target::SVE2);

@@ -1559,8 +1556,8 @@ void CodeGen_ARM::visit(const Store *op) {
llvm::FunctionCallee fn = module->getOrInsertFunction(instr.str(), fn_type);
internal_assert(fn);

// SVE2 supports predication for smaller than whole vector size.
internal_assert(target.has_feature(Target::SVE2) || (t.lanes() >= intrin_type.lanes()));
// Scalable vectors support predication for sizes smaller than the whole vector.
internal_assert(target_vscale() > 0 || (t.lanes() >= intrin_type.lanes()));

for (int i = 0; i < t.lanes(); i += intrin_type.lanes()) {
Expr slice_base = simplify(ramp->base + i * num_vecs);
@@ -1581,15 +1578,10 @@ void CodeGen_ARM::visit(const Store *op) {
slice_args.push_back(ConstantInt::get(i32_t, alignment));
} else {
if (is_sve) {
// Set the predicate argument
// Set the predicate argument to mask active lanes
auto active_lanes = std::min(t.lanes() - i, intrin_type.lanes());
Value *vpred_val;
if (is_predicated_store) {
vpred_val = slice_vector(store_pred_val, i, intrin_type.lanes());
} else {
Expr vpred = make_vector_predicate_1s_0s(active_lanes, intrin_type.lanes() - active_lanes);
vpred_val = codegen(vpred);
}
Expr vpred = make_vector_predicate_1s_0s(active_lanes, intrin_type.lanes() - active_lanes);
Value *vpred_val = codegen(vpred);
slice_args.push_back(vpred_val);
}
// Set the pointer argument
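To illustrate the predicate built above: make_vector_predicate_1s_0s(active, inactive) yields a mask with active leading ones followed by zeros, so a final, partially filled slice only stores its live lanes. A minimal stand-in, purely for illustration rather than the actual Halide helper:

#include <vector>

// Illustrative stand-in for make_vector_predicate_1s_0s: `active` leading ones
// followed by `inactive` zeros, used as the per-lane mask of the SVE stN call.
static std::vector<int> predicate_1s_0s(int active, int inactive) {
    std::vector<int> mask(active, 1);
    mask.insert(mask.end(), inactive, 0);
    return mask;
}

// Example: a 4-lane intrinsic slice with 3 live lanes -> {1, 1, 1, 0}.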
@@ -1810,74 +1802,6 @@ void CodeGen_ARM::visit(const Load *op) {
CodeGen_Posix::visit(op);
return;
}
} else if (stride && (2 <= stride->value && stride->value <= 4)) {
// Structured load ST2/ST3/ST4 of SVE

Expr base = ramp->base;
ModulusRemainder align = op->alignment;

int aligned_stride = gcd(stride->value, align.modulus);
int offset = 0;
if (aligned_stride == stride->value) {
offset = mod_imp((int)align.remainder, aligned_stride);
} else {
const Add *add = base.as<Add>();
if (const IntImm *add_c = add ? add->b.as<IntImm>() : base.as<IntImm>()) {
offset = mod_imp(add_c->value, stride->value);
}
}

if (offset) {
base = simplify(base - offset);
}

Value *load_pred_val = codegen(op->predicate);

// We need to slice the result in to native vector lanes to use sve intrin.
// LLVM will optimize redundant ld instructions afterwards
const int slice_lanes = target.natural_vector_size(op->type);
vector<Value *> results;
for (int i = 0; i < op->type.lanes(); i += slice_lanes) {
int load_base_i = i * stride->value;
Expr slice_base = simplify(base + load_base_i);
Expr slice_index = Ramp::make(slice_base, stride, slice_lanes);
std::ostringstream instr;
instr << "llvm.aarch64.sve.ld"
<< stride->value
<< ".sret.nxv"
<< slice_lanes
<< (op->type.is_float() ? 'f' : 'i')
<< op->type.bits();
llvm::Type *elt = llvm_type_of(op->type.element_of());
llvm::Type *slice_type = get_vector_type(elt, slice_lanes);
StructType *sret_type = StructType::get(module->getContext(), std::vector(stride->value, slice_type));
std::vector<llvm::Type *> arg_types{get_vector_type(i1_t, slice_lanes), ptr_t};
llvm::FunctionType *fn_type = FunctionType::get(sret_type, arg_types, false);
FunctionCallee fn = module->getOrInsertFunction(instr.str(), fn_type);

// Set the predicate argument
int active_lanes = std::min(op->type.lanes() - i, slice_lanes);

Expr vpred = make_vector_predicate_1s_0s(active_lanes, slice_lanes - active_lanes);
Value *vpred_val = codegen(vpred);
vpred_val = convert_fixed_or_scalable_vector_type(vpred_val, get_vector_type(vpred_val->getType()->getScalarType(), slice_lanes));
if (is_predicated_load) {
Value *sliced_load_vpred_val = slice_vector(load_pred_val, i, slice_lanes);
vpred_val = builder->CreateAnd(vpred_val, sliced_load_vpred_val);
}

Value *elt_ptr = codegen_buffer_pointer(op->name, op->type.element_of(), slice_base);
CallInst *load_i = builder->CreateCall(fn, {vpred_val, elt_ptr});
add_tbaa_metadata(load_i, op->name, slice_index);
// extract one element out of returned struct
Value *extracted = builder->CreateExtractValue(load_i, offset);
results.push_back(extracted);
}

// Retrieve original lanes
value = concat_vectors(results);
value = slice_vector(value, 0, op->type.lanes());
return;
} else if (op->index.type().is_vector()) {
// General Gather Load

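For context on the Load hunk above: the deleted block lowered SVE structured loads by slicing the result into native-width pieces and calling the llvm.aarch64.sve.ldN.sret intrinsics. A short sketch of how those deleted lines assembled the intrinsic name (the standalone function below is illustrative):

#include <sstream>
#include <string>

// Mirrors the deleted name construction: stride = 2 with 4 lanes of 32-bit
// float yields "llvm.aarch64.sve.ld2.sret.nxv4f32".
static std::string sve_ldn_intrinsic(int stride, int slice_lanes, bool is_float, int bits) {
    std::ostringstream instr;
    instr << "llvm.aarch64.sve.ld" << stride << ".sret.nxv"
          << slice_lanes << (is_float ? 'f' : 'i') << bits;
    return instr.str();
}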
110 changes: 52 additions & 58 deletions test/correctness/simd_op_check_sve2.cpp
@@ -677,79 +677,81 @@ class SimdOpCheckArmSve : public SimdOpCheckTest {
vector<tuple<Type, CastFuncTy>> test_params = {
{Int(8), in_i8}, {Int(16), in_i16}, {Int(32), in_i32}, {Int(64), in_i64}, {UInt(8), in_u8}, {UInt(16), in_u16}, {UInt(32), in_u32}, {UInt(64), in_u64}, {Float(16), in_f16}, {Float(32), in_f32}, {Float(64), in_f64}};

const int base_vec_bits = has_sve() ? target.vector_bits : 128;
const int vscale = base_vec_bits / 128;
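A quick worked example of the "<vscale x 1 x ty>" bail-out used in the loops below, assuming a hypothetical 256-bit SVE target (illustrative only, not test code):

#include <cassert>

// With 256-bit SVE registers, vscale = 256 / 128 = 2, so a 2-lane vector of
// 64-bit elements would map to <vscale x 1 x i64>; such cases are skipped.
int main() {
    const int base_vec_bits = 256;           // assumed target.vector_bits
    const int vscale = base_vec_bits / 128;  // as computed in the test
    const int vector_lanes = 2;              // e.g. 128 bits of i64
    assert(vector_lanes / vscale < 2);       // skipped: would be <vscale x 1 x i64>
    return 0;
}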

for (const auto &[elt, in_im] : test_params) {
const int bits = elt.bits();
if ((elt == Float(16) && !is_float16_supported()) ||
(is_arm32() && bits == 64)) {
continue;
}

// LD/ST - Load/Store
for (int width = 64; width <= 64 * 4; width *= 2) {
const int total_lanes = width / bits;
const int instr_lanes = min(total_lanes, 128 / bits);
if (instr_lanes < 2) continue; // bail out scalar op
// LD/ST - Load/Store scalar
// We skip the scalar load/store tests for the following reasons.
// The rule by which LLVM selects instructions is not simple: ld1, ldr, or ldp may be
// emitted, operating on a z or q register, depending on the data type, vscale, what is
// performed before/after the load, and the LLVM version.
// In addition, load/store instructions appear in places other than the one we want to
// check, which makes line-by-line string matching prone to false positives.

// In case of arm32, instruction selection looks inconsistent due to optimization by LLVM
AddTestFunctor add(*this, bits, total_lanes, target.bits == 64);
// NOTE: if the expr is too simple, LLVM might generate "bl memcpy"
Expr load_store_1 = in_im(x) * 3;
// LDn - Structured Load strided elements
if (Halide::Internal::get_llvm_version() >= 220) {
for (int stride = 2; stride <= 4; ++stride) {

if (has_sve()) {
// This pattern has changed with LLVM 21, see https://github.com/halide/Halide/issues/8584 for more
// details.
if (Halide::Internal::get_llvm_version() < 210) {
// in native width, ld1b/st1b is used regardless of data type
const bool allow_byte_ls = (width == target.vector_bits);
add({get_sve_ls_instr("ld1", bits, bits, "", allow_byte_ls ? "b" : "")}, total_lanes, load_store_1);
add({get_sve_ls_instr("st1", bits, bits, "", allow_byte_ls ? "b" : "")}, total_lanes, load_store_1);
for (int factor : {1, 2, 4}) {
const int vector_lanes = base_vec_bits * factor / bits;

// In StageStridedLoads.cpp, (stride < r->lanes) is the condition for staging to happen
// See https://github.com/halide/Halide/issues/8819
if (vector_lanes <= stride) continue;

AddTestFunctor add_ldn(*this, bits, vector_lanes);

Expr load_n = in_im(x * stride) + in_im(x * stride + stride - 1);

const string ldn_str = "ld" + to_string(stride);
if (has_sve()) {
add_ldn({get_sve_ls_instr(ldn_str, bits)}, vector_lanes, load_n);
} else {
add_ldn(sel_op("v" + ldn_str + ".", ldn_str), load_n);
}
}
} else {
// vector register is not used for simple load/store
string reg_prefix = (width <= 64) ? "d" : "q";
add({{"st[rp]", reg_prefix + R"(\d\d?)"}}, total_lanes, load_store_1);
add({{"ld[rp]", reg_prefix + R"(\d\d?)"}}, total_lanes, load_store_1);
}
}
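The guard in the LDn loop above mirrors the (stride < r->lanes) staging condition in StageStridedLoads.cpp, so combinations with too few lanes are skipped. A small sketch that enumerates the surviving cases, assuming base_vec_bits = 128 and 64-bit elements (illustrative only):

#include <cstdio>

// Illustrative only: which (stride, factor) combinations the LDn test loop
// exercises. Cases with vector_lanes <= stride are skipped, mirroring the
// staging condition in StageStridedLoads.cpp.
int main() {
    const int base_vec_bits = 128, bits = 64;
    for (int stride = 2; stride <= 4; ++stride) {
        for (int factor : {1, 2, 4}) {
            const int vector_lanes = base_vec_bits * factor / bits;
            if (vector_lanes <= stride) continue;
            std::printf("ld%d tested with %d lanes\n", stride, vector_lanes);
        }
    }
    return 0;
}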

// LD2/ST2 - Load/Store two-element structures
int base_vec_bits = has_sve() ? target.vector_bits : 128;
for (int width = base_vec_bits; width <= base_vec_bits * 4; width *= 2) {
// ST2 - Store two-element structures
for (int factor : {1, 2}) {
const int width = base_vec_bits * 2 * factor;
const int total_lanes = width / bits;
const int vector_lanes = total_lanes / 2;
const int instr_lanes = min(vector_lanes, base_vec_bits / bits);
if (instr_lanes < 2) continue; // bail out scalar op
if (instr_lanes < 2 || (vector_lanes / vscale < 2)) continue; // bail out scalar and <vscale x 1 x ty>

AddTestFunctor add_ldn(*this, bits, vector_lanes);
AddTestFunctor add_stn(*this, bits, instr_lanes, total_lanes);

Func tmp1, tmp2;
tmp1(x) = cast(elt, x);
tmp1.compute_root();
tmp2(x, y) = select(x % 2 == 0, tmp1(x / 2), tmp1(x / 2 + 16));
tmp2.compute_root().vectorize(x, total_lanes);
Expr load_2 = in_im(x * 2) + in_im(x * 2 + 1);
Expr store_2 = tmp2(0, 0) + tmp2(0, 127);

if (has_sve()) {
// TODO(issue needed): Added strided load support.
#if 0
add_ldn({get_sve_ls_instr("ld2", bits)}, vector_lanes, load_2);
#endif
add_stn({get_sve_ls_instr("st2", bits)}, total_lanes, store_2);
} else {
add_ldn(sel_op("vld2.", "ld2"), load_2);
add_stn(sel_op("vst2.", "st2"), store_2);
}
}

// Also check when the two expressions interleaved have a common
// subexpression, which results in a vector var being lifted out.
for (int width = base_vec_bits; width <= base_vec_bits * 4; width *= 2) {
for (int factor : {1, 2}) {
const int width = base_vec_bits * 2 * factor;
const int total_lanes = width / bits;
const int vector_lanes = total_lanes / 2;
const int instr_lanes = Instruction::get_instr_lanes(bits, vector_lanes, target);
if (instr_lanes < 2) continue; // bail out scalar op
if (instr_lanes < 2 || (vector_lanes / vscale < 2)) continue; // bail out scalar and <vscale x 1 x ty>

AddTestFunctor add_stn(*this, bits, instr_lanes, total_lanes);

@@ -768,14 +770,14 @@ class SimdOpCheckArmSve : public SimdOpCheckTest {
}
}
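The two ST2 loops above check that Halide's interleaving stores lower to st2 (or vst2 on arm32). As a reminder of what an interleaving store does to memory layout, a small sketch follows; the helper is illustrative, not test code:

#include <vector>

// Illustrative only: the layout an interleaving store produces. The ST2 test's
// select(x % 2 == 0, ...) pattern stores two source vectors a and b as
// a0 b0 a1 b1 ..., which the ARM backend emits as a single st2/vst2.
static std::vector<int> interleave2(const std::vector<int> &a, const std::vector<int> &b) {
    std::vector<int> out;
    for (size_t i = 0; i < a.size(); ++i) {
        out.push_back(a[i]);
        out.push_back(b[i]);
    }
    return out;
}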

// LD3/ST3 - Store three-element structures
for (int width = 192; width <= 192 * 4; width *= 2) {
// ST3 - Store three-element structures
for (int factor : {1, 2}) {
const int width = base_vec_bits * 3 * factor;
const int total_lanes = width / bits;
const int vector_lanes = total_lanes / 3;
const int instr_lanes = Instruction::get_instr_lanes(bits, vector_lanes, target);
if (instr_lanes < 2) continue; // bail out scalar op
if (instr_lanes < 2 || (vector_lanes / vscale < 2)) continue; // bail out scalar and <vscale x 1 x ty>

AddTestFunctor add_ldn(*this, bits, vector_lanes);
AddTestFunctor add_stn(*this, bits, instr_lanes, total_lanes);

Func tmp1, tmp2;
@@ -785,29 +787,25 @@ class SimdOpCheckArmSve : public SimdOpCheckTest {
x % 3 == 1, tmp1(x / 3 + 16),
tmp1(x / 3 + 32));
tmp2.compute_root().vectorize(x, total_lanes);
Expr load_3 = in_im(x * 3) + in_im(x * 3 + 1) + in_im(x * 3 + 2);
Expr store_3 = tmp2(0, 0) + tmp2(0, 127);

if (has_sve()) {
// TODO(issue needed): Added strided load support.
#if 0
add_ldn({get_sve_ls_instr("ld3", bits)}, vector_lanes, load_3);
add_stn({get_sve_ls_instr("st3", bits)}, total_lanes, store_3);
#endif
if (Halide::Internal::get_llvm_version() >= 220) {
add_stn({get_sve_ls_instr("st3", bits)}, total_lanes, store_3);
}
} else {
add_ldn(sel_op("vld3.", "ld3"), load_3);
add_stn(sel_op("vst3.", "st3"), store_3);
}
}

// LD4/ST4 - Store four-element structures
for (int width = 256; width <= 256 * 4; width *= 2) {
// ST4 - Store four-element structures
for (int factor : {1, 2}) {
const int width = base_vec_bits * 4 * factor;
const int total_lanes = width / bits;
const int vector_lanes = total_lanes / 4;
const int instr_lanes = Instruction::get_instr_lanes(bits, vector_lanes, target);
if (instr_lanes < 2) continue; // bail out scalar op
if (instr_lanes < 2 || (vector_lanes / vscale < 2)) continue; // bail out scalar and <vscale x 1 x ty>

AddTestFunctor add_ldn(*this, bits, vector_lanes);
AddTestFunctor add_stn(*this, bits, instr_lanes, total_lanes);

Func tmp1, tmp2;
@@ -818,17 +816,13 @@ class SimdOpCheckArmSve : public SimdOpCheckTest {
x % 4 == 2, tmp1(x / 4 + 32),
tmp1(x / 4 + 48));
tmp2.compute_root().vectorize(x, total_lanes);
Expr load_4 = in_im(x * 4) + in_im(x * 4 + 1) + in_im(x * 4 + 2) + in_im(x * 4 + 3);
Expr store_4 = tmp2(0, 0) + tmp2(0, 127);

if (has_sve()) {
// TODO(issue needed): Added strided load support.
#if 0
add_ldn({get_sve_ls_instr("ld4", bits)}, vector_lanes, load_4);
add_stn({get_sve_ls_instr("st4", bits)}, total_lanes, store_4);
#endif
if (Halide::Internal::get_llvm_version() >= 220) {
add_stn({get_sve_ls_instr("st4", bits)}, total_lanes, store_4);
}
} else {
add_ldn(sel_op("vld4.", "ld4"), load_4);
add_stn(sel_op("vst4.", "st4"), store_4);
}
}
@@ -838,7 +832,7 @@ class SimdOpCheckArmSve : public SimdOpCheckTest {
for (int width = 64; width <= 64 * 4; width *= 2) {
const int total_lanes = width / bits;
const int instr_lanes = min(total_lanes, 128 / bits);
if (instr_lanes < 2) continue; // bail out scalar op
if (instr_lanes < 2 || (total_lanes / vscale < 2)) continue; // bail out scalar and <vscale x 1 x ty>

AddTestFunctor add(*this, bits, total_lanes);
Expr index = clamp(cast<int>(in_im(x)), 0, W - 1);