diff --git a/aie_kernels/aie2/gelu.cc b/aie_kernels/aie2/gelu.cc index 57f426bd..d5f1f857 100644 --- a/aie_kernels/aie2/gelu.cc +++ b/aie_kernels/aie2/gelu.cc @@ -13,10 +13,8 @@ void gelu_tanh_approx_bf16(bfloat16 *restrict input_vector, bfloat16 *restrict o { event0(); - auto it_in = aie::begin_restrict_vector<16>((bfloat16 *)input_vector); - auto it_out = aie::begin_restrict_vector<16>((bfloat16 *)output_vector); - - aie::vector input; + auto it_in = aie::begin_restrict_vector<32>((bfloat16 *)input_vector); + auto it_out = aie::begin_restrict_vector<32>((bfloat16 *)output_vector); // Constants const bfloat16 k0_5 = 0.5f; @@ -24,33 +22,33 @@ void gelu_tanh_approx_bf16(bfloat16 *restrict input_vector, bfloat16 *restrict o const bfloat16 sqrt_2_over_pi = 0.79788456f; // ≈ sqrt(2/π) const bfloat16 kBeta = 0.044715f; - auto v05 = aie::broadcast(k0_5); - auto v1 = aie::broadcast(k1); - auto vs2opi = aie::broadcast(sqrt_2_over_pi); - auto vBeta = aie::broadcast(kBeta); + auto v05 = aie::broadcast(k0_5); + auto v1 = aie::broadcast(k1); + auto vs2opi = aie::broadcast(sqrt_2_over_pi); + auto vBeta = aie::broadcast(kBeta); AIE_PREPARE_FOR_PIPELINING AIE_LOOP_MIN_ITERATION_COUNT(64) - for (int i = 0; i < vector_size; i += 16) { - input = *it_in++; - auto x = input; + for (int i = 0; i < vector_size; i += 32) { + auto x = *it_in++; // Compute x^3 - aie::vector x2 = aie::mul(x, x); // x^2 - aie::vector x3 = aie::mul(x, x2); // x^3 + aie::vector x2 = aie::mul(x, x); // x^2 + aie::vector x3 = aie::mul(x, x2); // x^3 // inner = sqrt(2/pi) * (x + 0.044715 * x^3) - aie::vector x3_beta = aie::mul(x3, vBeta); - aie::vector inner = aie::add(x, x3_beta); - auto inner1 = aie::mul(inner, vs2opi); + aie::vector x3_beta = aie::mul(x3, vBeta); + aie::vector inner = aie::add(x, x3_beta); + aie::vector inner1 = aie::mul(inner, vs2opi); - // tanh_out = tanh(inner) - aie::vector tanh_out = getTanhBf16(inner1.to_vector()); + // LUT-based tanh: split to 16-wide halves + aie::vector tanh_lo = getTanhBf16(inner1.extract<16>(0)); + aie::vector tanh_hi = getTanhBf16(inner1.extract<16>(1)); + aie::vector tanh_out = aie::concat(tanh_lo, tanh_hi); // result = 0.5 * x * (1 + tanh_out) - aie::vector one_plus_tanh = aie::add(tanh_out, v1); - // Multiply by x and 0.5 - aie::vector mul_v05 = aie::mul(v05, one_plus_tanh); + aie::vector one_plus_tanh = aie::add(tanh_out, v1); + aie::vector mul_v05 = aie::mul(v05, one_plus_tanh); auto result = aie::mul(x, mul_v05); *it_out++ = result.to_vector(); diff --git a/aie_kernels/aie2/relu.cc b/aie_kernels/aie2/relu.cc index fd47379d..6ec4b8c0 100644 --- a/aie_kernels/aie2/relu.cc +++ b/aie_kernels/aie2/relu.cc @@ -15,7 +15,7 @@ void relu_vectorized_bf16(bfloat16 *restrict a, bfloat16 *restrict c, const int3 { event0(); - const int v_factor = 16; + const int v_factor = 32; v32bfloat16 zeroes = broadcast_zero_to_v32bfloat16(); AIE_PREPARE_FOR_PIPELINING AIE_LOOP_RANGE(16, 16) diff --git a/aie_kernels/aie2/rms_norm.cc b/aie_kernels/aie2/rms_norm.cc index 7c9d7c2b..8098c3c8 100644 --- a/aie_kernels/aie2/rms_norm.cc +++ b/aie_kernels/aie2/rms_norm.cc @@ -69,11 +69,11 @@ void rms_norm_general(const T *restrict input, const T *restrict input2, T *rest extern "C" { void rms_norm_bf16_vector(bfloat16 *input, bfloat16 *output, int32_t size) { - rms_norm_general(input, nullptr, output, size); + rms_norm_general(input, nullptr, output, size); } void weighted_rms_norm(bfloat16 *a_in, bfloat16 *b_in, bfloat16 *c_out, int32_t size) { - rms_norm_general(a_in, b_in, c_out, size); + rms_norm_general(a_in, b_in, c_out, size); } } diff --git a/aie_kernels/aie2/sigmoid.cc b/aie_kernels/aie2/sigmoid.cc index a740fe06..c89077dc 100644 --- a/aie_kernels/aie2/sigmoid.cc +++ b/aie_kernels/aie2/sigmoid.cc @@ -15,24 +15,27 @@ void sigmoid_tanh_approx_bf16(bfloat16 *restrict input_vector, { event0(); - auto it_in = aie::begin_restrict_vector<16>((bfloat16 *)input_vector); - auto it_out = aie::begin_restrict_vector<16>((bfloat16 *)output_vector); + auto it_in = aie::begin_restrict_vector<32>((bfloat16 *)input_vector); + auto it_out = aie::begin_restrict_vector<32>((bfloat16 *)output_vector); - aie::vector register_0_5 = aie::broadcast(0.5f); - aie::vector register_1 = aie::broadcast(1.0f); + aie::vector register_0_5 = aie::broadcast(0.5f); + aie::vector register_1 = aie::broadcast(1.0f); AIE_PREPARE_FOR_PIPELINING AIE_LOOP_MIN_ITERATION_COUNT(64) - for (int i = 0; i < vector_size; i += 16) { - // Load input vector - aie::vector input = *it_in++; + for (int i = 0; i < vector_size; i += 32) { + auto input = *it_in++; - // Compute tanh approximation - aie::vector half_x = aie::mul(input, register_0_5); - aie::vector tanh_half_x = getTanhBf16(half_x); - auto tanh_half_x_approx = aie::add(tanh_half_x, register_1); - aie::vector sigmoid_approx = aie::mul(tanh_half_x_approx, register_0_5); + // Compute half_x = x * 0.5 + aie::vector half_x = aie::mul(input, register_0_5); + + // LUT-based tanh: split to 16-wide halves + aie::vector tanh_lo = getTanhBf16(half_x.extract<16>(0)); + aie::vector tanh_hi = getTanhBf16(half_x.extract<16>(1)); + aie::vector tanh_half_x = aie::concat(tanh_lo, tanh_hi); + + auto one_plus = aie::add(tanh_half_x, register_1); + aie::vector sigmoid_approx = aie::mul(one_plus, register_0_5); - // Store output vector *it_out++ = sigmoid_approx; } diff --git a/aie_kernels/aie2/silu.cc b/aie_kernels/aie2/silu.cc index 3ab9b9aa..3b364b17 100644 --- a/aie_kernels/aie2/silu.cc +++ b/aie_kernels/aie2/silu.cc @@ -13,26 +13,28 @@ void silu_tanh_approx_bf16(bfloat16 *restrict input_vector, bfloat16 *restrict o { event0(); - auto it_in = aie::begin_restrict_vector<16>((bfloat16 *)input_vector); - auto it_out = aie::begin_restrict_vector<16>((bfloat16 *)output_vector); + auto it_in = aie::begin_restrict_vector<32>((bfloat16 *)input_vector); + auto it_out = aie::begin_restrict_vector<32>((bfloat16 *)output_vector); - aie::vector register_0_5 = aie::broadcast(0.5f); - aie::vector register_1 = aie::broadcast(1.0f); + aie::vector register_0_5 = aie::broadcast(0.5f); + aie::vector register_1 = aie::broadcast(1.0f); AIE_PREPARE_FOR_PIPELINING AIE_LOOP_MIN_ITERATION_COUNT(64) - for (int i = 0; i < vector_size; i += 16) { - // Load input vector - aie::vector input = *it_in++; - - // Compute tanh approximation - aie::vector half_x = aie::mul(input, register_0_5); - aie::vector tanh_half_x = getTanhBf16(half_x); - auto tanh_half_x_approx = aie::add(tanh_half_x, register_1); - aie::vector sigmoid_approx = aie::mul(tanh_half_x_approx, register_0_5); - // Compute output: x * tanh_approx + for (int i = 0; i < vector_size; i += 32) { + auto input = *it_in++; + + // Compute half_x = x * 0.5 + aie::vector half_x = aie::mul(input, register_0_5); + + // LUT-based tanh: split to 16-wide halves + aie::vector tanh_lo = getTanhBf16(half_x.extract<16>(0)); + aie::vector tanh_hi = getTanhBf16(half_x.extract<16>(1)); + aie::vector tanh_half_x = aie::concat(tanh_lo, tanh_hi); + + auto one_plus = aie::add(tanh_half_x, register_1); + aie::vector sigmoid_approx = aie::mul(one_plus, register_0_5); auto mul_output = aie::mul(input, sigmoid_approx); - // Store output vector *it_out++ = mul_output.to_vector(); } diff --git a/aie_kernels/aie2/tanh.cc b/aie_kernels/aie2/tanh.cc index 186d0c2c..f7e3e33b 100644 --- a/aie_kernels/aie2/tanh.cc +++ b/aie_kernels/aie2/tanh.cc @@ -13,20 +13,19 @@ void tanh_bf16_vectorized(bfloat16 *restrict input_vector, bfloat16 *restrict ou { event0(); - auto it_in = aie::begin_restrict_vector<16>((bfloat16 *)input_vector); - auto it_out = aie::begin_restrict_vector<16>((bfloat16 *)output_vector); + auto it_in = aie::begin_restrict_vector<32>((bfloat16 *)input_vector); + auto it_out = aie::begin_restrict_vector<32>((bfloat16 *)output_vector); AIE_PREPARE_FOR_PIPELINING AIE_LOOP_MIN_ITERATION_COUNT(64) - for (int i = 0; i < vector_size; i += 16) { - // Load input vector - aie::vector input = *it_in++; + for (int i = 0; i < vector_size; i += 32) { + auto input = *it_in++; - // Compute tanh approximation - aie::vector tanh_x = getTanhBf16(input); + // LUT-based tanh: split to 16-wide halves + aie::vector tanh_lo = getTanhBf16(input.extract<16>(0)); + aie::vector tanh_hi = getTanhBf16(input.extract<16>(1)); - // Store output vector - *it_out++ = tanh_x; + *it_out++ = aie::concat(tanh_lo, tanh_hi); } event1(); diff --git a/aie_kernels/aie2p/gelu.cc b/aie_kernels/aie2p/gelu.cc index c964feb9..6c1f3b82 100644 --- a/aie_kernels/aie2p/gelu.cc +++ b/aie_kernels/aie2p/gelu.cc @@ -12,10 +12,8 @@ void gelu_tanh_approx_bf16(bfloat16 *restrict input_vector, bfloat16 *restrict o { event0(); - auto it_in = aie::begin_restrict_vector<16>((bfloat16 *)input_vector); - auto it_out = aie::begin_restrict_vector<16>((bfloat16 *)output_vector); - - aie::vector input; + auto it_in = aie::begin_restrict_vector<32>((bfloat16 *)input_vector); + auto it_out = aie::begin_restrict_vector<32>((bfloat16 *)output_vector); // Constants const bfloat16 k0_5 = 0.5f; @@ -23,33 +21,34 @@ void gelu_tanh_approx_bf16(bfloat16 *restrict input_vector, bfloat16 *restrict o const bfloat16 sqrt_2_over_pi = 0.79788456f; // ≈ sqrt(2/π) const bfloat16 kBeta = 0.044715f; - auto v05 = aie::broadcast(k0_5); - auto v1 = aie::broadcast(k1); + auto v05 = aie::broadcast(k0_5); + auto v1 = aie::broadcast(k1); auto vs2opi = aie::broadcast(sqrt_2_over_pi); - auto vBeta = aie::broadcast(kBeta); + auto vBeta = aie::broadcast(kBeta); AIE_PREPARE_FOR_PIPELINING AIE_LOOP_MIN_ITERATION_COUNT(64) - for (int i = 0; i < vector_size; i += 16) { - input = *it_in++; - auto x = input; + for (int i = 0; i < vector_size; i += 32) { + auto x = *it_in++; // Compute x^3 - aie::vector x2 = aie::mul(x, x); // x^2 - aie::vector x3 = aie::mul(x, x2); // x^3 + aie::vector x2 = aie::mul(x, x); // x^2 + aie::vector x3 = aie::mul(x, x2); // x^3 // inner = sqrt(2/pi) * (x + 0.044715 * x^3) - aie::vector x3_beta = aie::mul(x3, vBeta); - aie::vector inner = aie::add(x, x3_beta); - auto inner1 = aie::mul(inner, vs2opi); + aie::vector x3_beta = aie::mul(x3, vBeta); + aie::vector inner = aie::add(x, x3_beta); - // tanh_out = tanh(inner) - auto tanh_out = aie::tanh(inner1.to_vector()); + // tanh operates on 16 float lanes; split to two halves + auto inner1_lo = aie::mul(inner.extract<16>(0), vs2opi); + auto inner1_hi = aie::mul(inner.extract<16>(1), vs2opi); + auto tanh_lo = aie::tanh(inner1_lo.to_vector()); + auto tanh_hi = aie::tanh(inner1_hi.to_vector()); + aie::vector tanh_out = aie::concat(tanh_lo, tanh_hi); // result = 0.5 * x * (1 + tanh_out) - aie::vector one_plus_tanh = aie::add(tanh_out, v1); - // Multiply by x and 0.5 - aie::vector mul_v05 = aie::mul(v05, one_plus_tanh); + aie::vector one_plus_tanh = aie::add(tanh_out, v1); + aie::vector mul_v05 = aie::mul(v05, one_plus_tanh); auto result = aie::mul(x, mul_v05); *it_out++ = result.to_vector(); diff --git a/aie_kernels/aie2p/layer_norm.cc b/aie_kernels/aie2p/layer_norm.cc index b9ef4cc4..834f71ee 100644 --- a/aie_kernels/aie2p/layer_norm.cc +++ b/aie_kernels/aie2p/layer_norm.cc @@ -103,6 +103,6 @@ extern "C" { void layer_norm(bfloat16 *input, bfloat16 *output, int32_t cols) { ::aie::set_rounding(aie::rounding_mode::conv_even); - layer_norm(input, output, cols); + layer_norm(input, output, cols); } } diff --git a/aie_kernels/aie2p/rms_norm.cc b/aie_kernels/aie2p/rms_norm.cc index 1a709309..4ef24edc 100644 --- a/aie_kernels/aie2p/rms_norm.cc +++ b/aie_kernels/aie2p/rms_norm.cc @@ -67,11 +67,11 @@ void rms_norm_general(const T *restrict input, const T *restrict input2, T *rest extern "C" { void rms_norm_bf16_vector(bfloat16 *input, bfloat16 *output, int32_t size) { - rms_norm_general(input, nullptr, output, size); + rms_norm_general(input, nullptr, output, size); } void weighted_rms_norm(bfloat16 *a_in, bfloat16 *b_in, bfloat16 *c_out, int32_t size) { - rms_norm_general(a_in, b_in, c_out, size); + rms_norm_general(a_in, b_in, c_out, size); } } diff --git a/aie_kernels/aie2p/sigmoid.cc b/aie_kernels/aie2p/sigmoid.cc index 95a88810..59b2005a 100644 --- a/aie_kernels/aie2p/sigmoid.cc +++ b/aie_kernels/aie2p/sigmoid.cc @@ -15,26 +15,27 @@ void sigmoid_tanh_approx_bf16(bfloat16 *restrict input_vector, event0(); int num_elems = vector_size; - auto it_in = aie::begin_restrict_vector<16>((bfloat16 *)input_vector); - auto it_out = aie::begin_restrict_vector<16>((bfloat16 *)output_vector); + auto it_in = aie::begin_restrict_vector<32>((bfloat16 *)input_vector); + auto it_out = aie::begin_restrict_vector<32>((bfloat16 *)output_vector); - aie::vector input; - aie::vector output; aie::vector register_0_5 = aie::broadcast(0.5f); - aie::vector register_1 = aie::broadcast(1.0f); + aie::vector register_1 = aie::broadcast(1.0f); + aie::vector register_0_5_wide = aie::broadcast(0.5f); AIE_PREPARE_FOR_PIPELINING AIE_LOOP_MIN_ITERATION_COUNT(64) - for (int i = 0; i < num_elems; i += 16) { - // Load input vector - input = *it_in++; + for (int i = 0; i < num_elems; i += 32) { + auto input = *it_in++; - // Compute tanh approximation - auto half_x = aie::mul(input, register_0_5); - auto tanh_half_x = aie::tanh(half_x.to_vector()); - auto tanh_half_x_approx = aie::add(tanh_half_x, register_1); - aie::vector sigmoid_approx = aie::mul(tanh_half_x_approx, register_0_5); + // tanh(x/2) split to two 16-wide halves + auto half_x_lo = aie::mul(input.extract<16>(0), register_0_5); + auto half_x_hi = aie::mul(input.extract<16>(1), register_0_5); + auto tanh_lo = aie::tanh(half_x_lo.to_vector()); + auto tanh_hi = aie::tanh(half_x_hi.to_vector()); + aie::vector tanh_half_x = aie::concat(tanh_lo, tanh_hi); + + auto one_plus = aie::add(tanh_half_x, register_1); + aie::vector sigmoid_approx = aie::mul(one_plus, register_0_5_wide); - // Store output vector *it_out++ = sigmoid_approx; } diff --git a/aie_kernels/aie2p/silu.cc b/aie_kernels/aie2p/silu.cc index 2610a80e..e6ca3e14 100644 --- a/aie_kernels/aie2p/silu.cc +++ b/aie_kernels/aie2p/silu.cc @@ -13,28 +13,28 @@ void silu_tanh_approx_bf16(bfloat16 *restrict input_vector, bfloat16 *restrict o event0(); int num_elems = vector_size; - auto it_in = aie::begin_restrict_vector<16>((bfloat16 *)input_vector); - auto it_out = aie::begin_restrict_vector<16>((bfloat16 *)output_vector); + auto it_in = aie::begin_restrict_vector<32>((bfloat16 *)input_vector); + auto it_out = aie::begin_restrict_vector<32>((bfloat16 *)output_vector); - aie::vector input; - aie::vector output; aie::vector register_0_5 = aie::broadcast(0.5f); - aie::vector register_1 = aie::broadcast(1.0f); + aie::vector register_1 = aie::broadcast(1.0f); + aie::vector register_0_5_wide = aie::broadcast(0.5f); AIE_PREPARE_FOR_PIPELINING AIE_LOOP_MIN_ITERATION_COUNT(64) - for (int i = 0; i < num_elems; i += 16) { - // Load input vector - input = *it_in++; - - // Compute tanh approximation - auto half_x = aie::mul(input, register_0_5); - auto tanh_half_x = aie::tanh(half_x.to_vector()); - auto tanh_half_x_approx = aie::add(tanh_half_x, register_1); - aie::vector sigmoid_approx = aie::mul(tanh_half_x_approx, register_0_5); - // Compute output: x * tanh_approx + for (int i = 0; i < num_elems; i += 32) { + auto input = *it_in++; + + // tanh(x/2) split to two 16-wide halves + auto half_x_lo = aie::mul(input.extract<16>(0), register_0_5); + auto half_x_hi = aie::mul(input.extract<16>(1), register_0_5); + auto tanh_lo = aie::tanh(half_x_lo.to_vector()); + auto tanh_hi = aie::tanh(half_x_hi.to_vector()); + aie::vector tanh_half_x = aie::concat(tanh_lo, tanh_hi); + + auto one_plus = aie::add(tanh_half_x, register_1); + aie::vector sigmoid_approx = aie::mul(one_plus, register_0_5_wide); auto mul_output = aie::mul(input, sigmoid_approx); - // Store output vector *it_out++ = mul_output.to_vector(); } diff --git a/aie_kernels/aie2p/tanh.cc b/aie_kernels/aie2p/tanh.cc index 76c90988..5d70d4ec 100644 --- a/aie_kernels/aie2p/tanh.cc +++ b/aie_kernels/aie2p/tanh.cc @@ -13,24 +13,23 @@ void tanh_bf16_vectorized(bfloat16 *restrict input_vector, bfloat16 *restrict ou event0(); int num_elems = vector_size; - auto it_in = aie::begin_restrict_vector<16>((bfloat16 *)input_vector); - auto it_out = aie::begin_restrict_vector<16>((bfloat16 *)output_vector); + auto it_in = aie::begin_restrict_vector<32>((bfloat16 *)input_vector); + auto it_out = aie::begin_restrict_vector<32>((bfloat16 *)output_vector); - aie::vector input; - aie::accum acc; - aie::vector output; AIE_PREPARE_FOR_PIPELINING AIE_LOOP_MIN_ITERATION_COUNT(64) - for (int i = 0; i < num_elems; i += 16) { - // Load input vector - input = *it_in++; - - // Compute tanh - acc.from_vector(input, 0); - auto tanh_x = aie::tanh(acc.to_vector()); - - // Store output vector - *it_out++ = tanh_x; + for (int i = 0; i < num_elems; i += 32) { + auto input = *it_in++; + + // vtanh operates on 16 float lanes; split to two halves + aie::accum acc_lo; + aie::accum acc_hi; + acc_lo.from_vector(input.extract<16>(0), 0); + acc_hi.from_vector(input.extract<16>(1), 0); + auto tanh_lo = aie::tanh(acc_lo.to_vector()); + auto tanh_hi = aie::tanh(acc_hi.to_vector()); + + *it_out++ = aie::concat(tanh_lo, tanh_hi); } event1(); diff --git a/aie_kernels/generic/add.cc b/aie_kernels/generic/add.cc index 8de54381..338a0c3f 100644 --- a/aie_kernels/generic/add.cc +++ b/aie_kernels/generic/add.cc @@ -21,7 +21,7 @@ template void eltwise_add(T_in *a, T_in *b, T_ou template void eltwise_vadd(T_in *a, T_in *b, T_out *c, int size) { - constexpr int vec_factor = 16; + constexpr int vec_factor = 32; event0(); T_in *__restrict pA1 = a; T_in *__restrict pB1 = b; diff --git a/aie_kernels/generic/mul.cc b/aie_kernels/generic/mul.cc index 6f109431..500cde88 100644 --- a/aie_kernels/generic/mul.cc +++ b/aie_kernels/generic/mul.cc @@ -20,9 +20,9 @@ template void eltwise_vmul(T_in *a, T_in *b, T_o { event0(); - for (int i = 0; i < size; i += 16) { - auto A = aie::load_v<16>(a + i); - auto B = aie::load_v<16>(b + i); + for (int i = 0; i < size; i += 32) { + auto A = aie::load_v<32>(a + i); + auto B = aie::load_v<32>(b + i); auto C = aie::mul(A, B).template to_vector(); aie::store_v(c + i, C); } diff --git a/aie_kernels/generic/rope.cc b/aie_kernels/generic/rope.cc index 868016c5..aafd0c1d 100644 --- a/aie_kernels/generic/rope.cc +++ b/aie_kernels/generic/rope.cc @@ -75,9 +75,9 @@ extern "C" { void rope(bfloat16 *input, bfloat16 *lut, bfloat16 *output, int32_t dims) { #if defined(TWO_HALVES) - rope_kernel_two_halves(input, lut, output, dims); // For the two-halves method used in HF transformers + rope_kernel_two_halves(input, lut, output, dims); // For the two-halves method used in HF transformers #elif defined(INTERLEAVED) - rope_kernel_interleaved( + rope_kernel_interleaved( input, lut, output, dims); // For the interleaved method used in the Llama paper #endif }