diff --git a/onnxruntime/core/providers/cpu/nn/layer_norm_helper.h b/onnxruntime/core/providers/cpu/nn/layer_norm_helper.h index ed5ea83d9de30..ac00580c98aff 100644 --- a/onnxruntime/core/providers/cpu/nn/layer_norm_helper.h +++ b/onnxruntime/core/providers/cpu/nn/layer_norm_helper.h @@ -6,14 +6,14 @@ #include "core/framework/tensor_shape.h" #include "core/common/status.h" #include "core/common/narrow.h" +#include "core/common/inlined_containers.h" namespace onnxruntime { constexpr const char* kLayerNormInputShapeMismatchError = - "Size of scale and bias (if provided) must match X.shape[axis:], " - "or scale and bias (with same shape) can be broadcasted to X when axis is 2."; + "Scale and (optional) bias must match X.shape[axis:] or be NumPy-broadcastable to it."; -constexpr const char* kLayerNormInvalidSize = "Size of X.shape[axis:] must be larger than 1, got "; +constexpr const char* kLayerNormInvalidSize = "Size of X.shape[axis:] must be at least 1, got "; constexpr int64_t kLayerNormInvalidInput = -1; @@ -23,15 +23,29 @@ struct LayerNormParams { int64_t scale_size; int64_t bias_size; int64_t broadcast_param; + bool use_generic_broadcast{false}; // true: full NumPy-style broadcast; false: legacy broadcast_param path + onnxruntime::InlinedVector x_dims; + onnxruntime::InlinedVector x_inner_dims; // X.shape[axis:] + onnxruntime::InlinedVector sc_dims; + onnxruntime::InlinedVector bi_dims; + onnxruntime::InlinedVector sc_strides; + onnxruntime::InlinedVector bi_strides; + int64_t axis{0}; + int64_t last_rank{0}; + onnxruntime::InlinedVector sc_inner_inc; // scale strides for inner dims [axis..] + onnxruntime::InlinedVector bi_inner_inc; // bias strides for inner dims [axis..] + onnxruntime::InlinedVector x_outer_strides; // X strides for outer dims [0..axis-1] }; -// We support broadcasting for axis=2, where the first two dimensions are rows, and the rest are columns. +// Fast-path broadcasting for axis = 2: // When X shape is (B, S, ...), and x_row (index of one row in X) is in the range of [0, B * S). -// We support scale and bias shape like below: +// We support the following scale/bias shapes in this path: // When scale and bias shape is (1, 1, ...) or (...), value of broadcast_param is 0. // When scale and bias shape is (B, 1, ...), value of broadcast_param is S. // When scale and bias shape is (B, S, ...), value of broadcast_param is 1. // When scale and bias shape is (1, S, ...), value of broadcast_param is -S. +// For all other NumPy-broadcastable shapes we fall back to the generic +// broadcasting path (use_generic_broadcast = true) and ignore broadcast_param. // Below is a macro to compute the offset for scale and bias data for a row of X. #ifndef LAYER_NORM_SCALE_BIAS_OFFSET @@ -48,30 +62,152 @@ class LayerNormHelper { bool has_bias, int64_t axis, LayerNormParams& params) { + // Initialize basic layout parameters: how many rows we have and how many elements + // are normalized per row, as well as the total scale/bias sizes. params.num_rows = x_shape.SizeToDimension(onnxruntime::narrow(axis)); params.norm_size = x_shape.SizeFromDimension(onnxruntime::narrow(axis)); params.scale_size = scale_shape.Size(); - params.bias_size = bias_shape.Size(); + params.bias_size = has_bias ? bias_shape.Size() : 0; + params.broadcast_param = 0; + params.axis = axis; - if (params.norm_size <= 1) { + // Allow norm_size == 1 (scalar normalization is valid according to ONNX spec). 
+ if (params.norm_size < 1) { params.broadcast_param = kLayerNormInvalidInput; return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, kLayerNormInvalidSize, params.norm_size); } else if (params.scale_size != params.norm_size || (has_bias && params.bias_size != params.scale_size)) { params.broadcast_param = GetBroadcastParam(x_shape, scale_shape, has_bias ? &bias_shape : nullptr, axis); - if (params.broadcast_param == kLayerNormInvalidInput) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - kLayerNormInputShapeMismatchError, - " X.shape=", x_shape, - " scale.shape=", scale_shape, - " bias.shape=", bias_shape, - " and axis=", axis); + // Try to encode simple (B, S, ...) layouts into broadcast_param so that the + // fast-path can be used. If this fails, broadcast_param will be set to + // kLayerNormInvalidInput and we may fall back to generic broadcasting later. + } + const size_t xr = x_shape.NumDimensions(); + const size_t sr = scale_shape.NumDimensions(); + const size_t br = has_bias ? bias_shape.NumDimensions() : 0; + + params.x_dims.clear(); + params.x_dims.reserve(xr); + for (size_t i = 0; i < xr; ++i) { + params.x_dims.push_back(x_shape.GetDims()[i]); + } + + // Right-align scale and bias shapes + params.sc_dims.clear(); + params.sc_dims.resize(xr, 1); + for (size_t i = 0; i < sr; ++i) { + params.sc_dims[xr - 1 - i] = scale_shape.GetDims()[sr - 1 - i]; + } + + params.bi_dims.clear(); + if (has_bias) { + params.bi_dims.resize(xr, 1); + for (size_t i = 0; i < br; ++i) { + params.bi_dims[xr - 1 - i] = bias_shape.GetDims()[br - 1 - i]; + } + } + + // Validate broadcastability + const bool sc_ok = IsNumpyBroadcastable(params.sc_dims, params.x_dims); + const bool bi_ok = !has_bias || IsNumpyBroadcastable(params.bi_dims, params.x_dims); + if (!sc_ok || !bi_ok) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + kLayerNormInputShapeMismatchError, + " X.shape=", x_shape, + " scale.shape=", scale_shape, + " bias.shape=", bias_shape, + " and axis=", axis); + } + + // Compute strides for scale/bias once + params.sc_strides = MakeStrides(params.sc_dims); + params.bi_strides.clear(); + if (has_bias) { + params.bi_strides = MakeStrides(params.bi_dims); + } + + // Detect dependency on outer dimensions [0..axis-1] + bool outer_dep = false; + for (int64_t i = 0; i < axis; ++i) { + const size_t idx = static_cast(i); + if (params.sc_strides[idx] != 0 || + (has_bias && params.bi_strides[idx] != 0)) { + outer_dep = true; + break; } } + + // Decide if we need the generic NumPy-style broadcasting path + params.use_generic_broadcast = outer_dep || (params.broadcast_param == kLayerNormInvalidInput); + + if (params.use_generic_broadcast) { + // Cache inner dims X.shape[axis:] + params.last_rank = onnxruntime::narrow(xr) - axis; + params.x_inner_dims.clear(); + params.x_inner_dims.reserve(params.last_rank > 0 ? static_cast(params.last_rank) : 0); + for (size_t i = static_cast(axis); i < xr; ++i) { + params.x_inner_dims.push_back(params.x_dims[i]); + } + + // Precompute inner increments for scale/bias over [axis..] 
+ params.sc_inner_inc.clear(); + params.bi_inner_inc.clear(); + for (size_t i = static_cast(axis); i < xr; ++i) { + params.sc_inner_inc.push_back(params.sc_strides[i]); + if (has_bias) { + params.bi_inner_inc.push_back(params.bi_strides[i]); + } + } + + // X outer strides [0..axis-1], used only in generic path + params.x_outer_strides.clear(); + params.x_outer_strides.resize(static_cast(axis), 1); + if (axis > 1) { + for (int64_t d = axis - 2; d >= 0; --d) { + const size_t du = static_cast(d); + params.x_outer_strides[du] = + params.x_outer_strides[du + 1] * params.x_dims[du + 1]; + } + } + } else { + // Fast-path: we don't need inner/outer increments + params.last_rank = 0; + params.x_inner_dims.clear(); + params.sc_inner_inc.clear(); + params.bi_inner_inc.clear(); + params.x_outer_strides.clear(); + } + return Status::OK(); } private: + static bool IsNumpyBroadcastable(gsl::span a, + gsl::span b) { + ORT_ENFORCE(a.size() == b.size()); + for (size_t k = 0; k < a.size(); ++k) { + const int64_t ak = a[k]; + const int64_t bk = b[k]; + if (!(ak == 1 || ak == bk)) { + return false; + } + } + return true; + } + static InlinedVector MakeStrides(const InlinedVector& dims) { + InlinedVector strides(dims.size(), 0); + if (dims.empty()) return strides; + + int64_t running = 1; + for (ptrdiff_t i = dims.size() - 1; i >= 0; --i) { + size_t idx = static_cast(i); + strides[idx] = (dims[idx] == 1) ? 0 : running; + running *= std::max(1, dims[idx]); + } + + return strides; + } + static int64_t GetBroadcastParam(const TensorShape& x_shape, const TensorShape& scale_shape, const TensorShape* bias_shape, diff --git a/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc b/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc index 237110483416c..712dcaf9d7034 100644 --- a/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc +++ b/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc @@ -159,6 +159,256 @@ void ComputeJob( inv_std_dev_data[task_idx] = MLFloat16(1.0f / mean_square); } } +// Write a statistic value (mean or 1/denom) into the output buffer, +// converting from double to the target type U (including MLFloat16). 
+template +ORT_FORCEINLINE void WriteStat(U* dst, ptrdiff_t index, double v) { + if constexpr (std::is_same_v) { + dst[index] = MLFloat16(static_cast(v)); + } else { + dst[index] = gsl::narrow_cast(v); + } +} +template +struct NormalizationMath { + static double LoadInput(const T* ptr, int64_t offset) { + return static_cast(ptr[offset]); + } + + static double LoadScale(const T* scale_data, + const float* scale_float_ptr, + int64_t offset) { + ORT_UNUSED_PARAMETER(scale_float_ptr); + return static_cast(scale_data[offset]); + } + + static double LoadBias(const T* bias_data, + const float* bias_float_ptr, + int64_t offset) { + ORT_UNUSED_PARAMETER(bias_float_ptr); + if (!bias_data) { + return 0.0; + } + return static_cast(bias_data[offset]); + } + + static void StoreOutput(T* dst, int64_t offset, double v) { + dst[offset] = static_cast(v); + } +}; + +struct HalfMath { + static double LoadInput(const MLFloat16* ptr, int64_t offset) { + return static_cast(static_cast(ptr[offset])); + } + + static double LoadScale(const MLFloat16* scale_data, + const float* scale_float_ptr, + int64_t offset) { + if (scale_float_ptr) { + return static_cast(scale_float_ptr[offset]); + } + return static_cast(static_cast(scale_data[offset])); + } + + static double LoadBias(const MLFloat16* bias_data, + const float* bias_float_ptr, + int64_t offset) { + if (bias_float_ptr) { + return static_cast(bias_float_ptr[offset]); + } + if (bias_data) { + return static_cast(static_cast(bias_data[offset])); + } + return 0.0; + } + + static void StoreOutput(MLFloat16* dst, int64_t offset, double v) { + dst[offset] = MLFloat16(static_cast(v)); + } +}; +// Shared generic implementation for LayerNorm with full NumPy-style broadcasting. +// DataT - storage type (float/double/MLFloat16) +// MathPolicy - policy that handles load/store/cast for DataT +// U - statistics output type (float, MLFloat16, etc.) +template +void ComputeJobGenericShared( + const DataT* X_data, + const DataT* scale_data, + const DataT* bias_data, + const ptrdiff_t task_idx, + const LayerNormParams& params, + const float* scale_float_ptr, + const float* bias_float_ptr, + float epsilon, + bool simplified, + DataT* Y_data, + U* mean_data, + U* inv_std_dev_data) { + const int64_t norm_size = params.norm_size; + const int64_t last_rank = params.last_rank; + + const DataT* p_input = X_data + task_idx * norm_size; + DataT* p_output = Y_data + task_idx * norm_size; + + // Compute mean and denom (same for all types, via MathPolicy). + double mean = 0.0; + double mean_sq = 0.0; + for (int64_t h = 0; h < norm_size; ++h) { + const double xv = MathPolicy::LoadInput(p_input, h); + mean += xv; + mean_sq += xv * xv; + } + + mean /= static_cast(norm_size); + const double denom = simplified + ? std::sqrt(mean_sq / norm_size + epsilon) + : std::sqrt(mean_sq / norm_size - mean * mean + epsilon); + + // Compute outer offsets for this logical row (same as before). + int64_t off_sc_row = 0; + int64_t off_bi_row = 0; + + const bool has_bias_any = (bias_data != nullptr) || (bias_float_ptr != nullptr); + + if (params.axis > 0) { + const auto& outer_strides = params.x_outer_strides; + + for (int64_t d = 0; d < params.axis; ++d) { + const size_t du = static_cast(d); + const int64_t dim = params.x_dims[du]; + const int64_t idx_d = (dim == 0) + ? 
0 + : (task_idx / outer_strides[du]) % dim; + + off_sc_row += idx_d * params.sc_strides[du]; + if (has_bias_any) { + off_bi_row += idx_d * params.bi_strides[du]; + } + } + } + + // Prepare inner-dimension iteration (multi-dimensional idx for inner dims, + // plus optimized inner loop over the last dimension). + ORT_ENFORCE(last_rank > 0); + onnxruntime::InlinedVector idx(static_cast(last_rank), 0); + + const auto& x_inner_dims = params.x_inner_dims; + const auto& sc_inner_inc = params.sc_inner_inc; + const auto& bi_inner_inc = params.bi_inner_inc; + + const int64_t last_dim = x_inner_dims[static_cast(last_rank - 1)]; + ORT_ENFORCE(last_dim > 0); + ORT_ENFORCE(norm_size % last_dim == 0); + const int64_t num_chunks = norm_size / last_dim; + + const int64_t sc_last_stride = !sc_inner_inc.empty() ? sc_inner_inc.back() : 0; + const int64_t bi_last_stride = + (has_bias_any && !bi_inner_inc.empty()) ? bi_inner_inc.back() : 0; + + // Outer loop: iterate over "chunks" of the last dimension. + for (int64_t c = 0; c < num_chunks; ++c) { + int64_t off_sc = off_sc_row; + int64_t off_bi = off_bi_row; + + // Base offsets for all inner dims except the last. + for (int64_t d = 0; d < last_rank - 1; ++d) { + const size_t du = static_cast(d); + off_sc += idx[du] * sc_inner_inc[du]; + if (has_bias_any) { + off_bi += idx[du] * bi_inner_inc[du]; + } + } + + const int64_t base_h = c * last_dim; + + // Tight inner loop over the last dimension: compiler can vectorize this. + for (int64_t i = 0; i < last_dim; ++i) { + const int64_t h = base_h + i; + + const int64_t sc_offset = off_sc + i * sc_last_stride; + const int64_t bi_offset = off_bi + i * bi_last_stride; + + const double x = MathPolicy::LoadInput(p_input, h); + const double s = MathPolicy::LoadScale(scale_data, scale_float_ptr, sc_offset); + const double b = MathPolicy::LoadBias(bias_data, bias_float_ptr, bi_offset); + + const double y = simplified + ? (x / denom) * s + : ((x - mean) / denom) * s + b; + + MathPolicy::StoreOutput(p_output, h, y); + } + + // Update multi-dimensional index 'idx' for the next chunk + // (iterate backwards from the second-to-last dimension). + if (last_rank > 1) { + for (int64_t d = last_rank - 2; d >= 0; --d) { + const size_t du = static_cast(d); + if (++idx[du] < x_inner_dims[du]) { + break; + } + idx[du] = 0; + } + } + } + + // Write statistics outputs. 
+ if (mean_data) { + WriteStat(mean_data, task_idx, mean); + } + if (inv_std_dev_data) { + WriteStat(inv_std_dev_data, task_idx, 1.0 / denom); + } +} +template +void ComputeJobGeneric( + const T* X_data, + const T* scale_data, + const T* bias_data, + const ptrdiff_t task_idx, + const LayerNormParams& params, + const float* scale_float_ptr, + const float* bias_float_ptr, + float epsilon, + bool simplified, + T* Y_data, + U* mean_data, + U* inv_std_dev_data) { + ORT_UNUSED_PARAMETER(scale_float_ptr); + ORT_UNUSED_PARAMETER(bias_float_ptr); + + using Policy = NormalizationMath; + ComputeJobGenericShared( + X_data, scale_data, bias_data, + task_idx, params, + nullptr, + nullptr, + epsilon, simplified, + Y_data, mean_data, inv_std_dev_data); +} +template +void ComputeJobGeneric( + const MLFloat16* X_data, + const MLFloat16* scale_data, + const MLFloat16* bias_data, + const ptrdiff_t task_idx, + const LayerNormParams& params, + const float* scale_float_ptr, + const float* bias_float_ptr, + float epsilon, + bool simplified, + MLFloat16* Y_data, + U* mean_data, + U* inv_std_dev_data) { + using Policy = HalfMath; + ComputeJobGenericShared( + X_data, scale_data, bias_data, + task_idx, params, + scale_float_ptr, bias_float_ptr, + epsilon, simplified, + Y_data, mean_data, inv_std_dev_data); +} void ConvertMLFloat16ToFloatIfNeeded(const Tensor& tensor, AllocatorPtr alloc, IAllocatorUniquePtr& dest, bool& is_packed) { if (tensor.GetElementType() == utils::ToTensorProtoElementType()) { @@ -277,7 +527,13 @@ Status LayerNormImpl::ComputeWithoutContext( bool simplified, AllocatorPtr alloc) const { LayerNormParams params; - ORT_RETURN_IF_ERROR(LayerNormHelper::CheckInputs(x_shape, scale_shape, bias_shape, bias_data != nullptr, axis, params)); + const bool has_bias = + !simplified && + (bias_data != nullptr || + (std::is_same_v && prepacked_bias_fp32_data_ != nullptr)); + + ORT_RETURN_IF_ERROR( + LayerNormHelper::CheckInputs(x_shape, scale_shape, bias_shape, has_bias, axis, params)); IAllocatorUniquePtr scale_fp32; IAllocatorUniquePtr bias_fp32; @@ -294,17 +550,42 @@ Status LayerNormImpl::ComputeWithoutContext( } } + // Resolve the float32 pointers for scale/bias (scf/bif) in the MLFloat16 case. + // For non-MLFloat16 types, these remain null and the original T* buffers are used. + const float* scf = nullptr; + const float* bif = nullptr; + + if constexpr (std::is_same_v) { + scf = prepacked_scale_fp32_data_ ? prepacked_scale_fp32_data_.get() + : scale_fp32.get(); + + if (has_bias) { + bif = prepacked_bias_fp32_data_ ? prepacked_bias_fp32_data_.get() + : (bias_fp32 ? bias_fp32.get() : nullptr); + } else { + bif = nullptr; + } + } + // Launch one normalization job per logical row in X. For each row we either: + // - use the generic NumPy-style broadcasting path, or + // - use the existing fast-path based on broadcast_param. concurrency::ThreadPool::TryBatchParallelFor( thread_pool, static_cast(params.num_rows), [&](ptrdiff_t task_idx) { - ComputeJob(X_data, scale_data, bias_data, task_idx, params.norm_size, params.broadcast_param, - prepacked_scale_fp32_data_ ? prepacked_scale_fp32_data_.get() : scale_fp32.get(), - prepacked_bias_fp32_data_ ? 
prepacked_bias_fp32_data_.get() : bias_fp32.get(), - epsilon, simplified, Y_data, mean_data, inv_std_dev_data, alloc); + if (params.use_generic_broadcast) { + ComputeJobGeneric(X_data, scale_data, bias_data, task_idx, params, + scf, bif, + epsilon, simplified, Y_data, mean_data, inv_std_dev_data); + } else { + ComputeJob(X_data, scale_data, bias_data, task_idx, + params.norm_size, params.broadcast_param, + scf, bif, + epsilon, simplified, Y_data, mean_data, inv_std_dev_data, alloc); + } }, 0); return Status::OK(); } -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc index 0d4fc5af68b4f..c97e70a550730 100644 --- a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc +++ b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc @@ -428,36 +428,377 @@ TEST(LayerNormTest, LayerNorm17_double) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kDnnlExecutionProvider}); } -// Test normalize size shall be larger than 1. -TEST(LayerNormTest, LayerNorm_InvalidNormSize) { +TEST(LayerNormTest, LayerNorm_NormSize1_Valid) { OpTester test("LayerNormalization"); test.AddAttribute("epsilon", 1e-05f); - std::vector dims{1, 3, 1}; test.AddInput("x", dims, {1.2416f, 0.946123f, 13.1685f}); test.AddInput("gamma", {1}, {-0.6953f}); test.AddInput("bias", {1}, {0.6435f}); test.AddAttribute("axis", 2); - test.AddOutput("output", dims, {-0.0516f, -5.5776f, -0.0518f}); - - RunTestOnCpuAndCuda(test, kLayerNormInvalidSize); + test.AddOutput("output", dims, {0.6435f, 0.6435f, 0.6435f}); + auto cpu = DefaultCpuExecutionProvider(); + if (!cpu) GTEST_SKIP() << "CPU EP not available in this build."; + test.ConfigEp(std::move(cpu)).RunWithConfig(); } -TEST(LayerNormTest, LayerNorm_InvalidScaleBias) { +TEST(LayerNormTest, LayerNorm_ValidScaleBias_Broadcast) { OpTester test("LayerNormalization"); test.AddAttribute("epsilon", 1e-05f); - // as axis is 1, the scale and bias should have size 6 + // With axis = 1, scale and bias of shape {2} are NumPy-broadcastable + // to X.shape[axis:] = {3, 2}, so this configuration is now valid. std::vector dims{1, 3, 2}; test.AddInput("x", dims, {1.2416f, 0.946123f, 13.1685f, 0.36423f, 21.145f, 0.03941f}); test.AddInput("gamma", {2}, {-0.6953f, 5.1824f}); test.AddInput("bias", {2}, {0.6435f, -0.3964f}); test.AddAttribute("axis", 1); - test.AddOutput("output", dims, {-0.0516f, -5.5776f, -0.0518f, -5.5788f, -0.0518f, -5.5788f}); + test.AddOutput("output", dims, + {1.063606f, -3.716114f, + 0.042961f, -4.087264f, + -0.639629f, -4.294445f}); + + // This configuration used to be rejected, but with generic NumPy-style + // broadcasting support it is now valid and should run successfully. 
+ auto cpu = DefaultCpuExecutionProvider(); + if (!cpu) GTEST_SKIP() << "CPU EP not available in this build."; + test.ConfigEp(std::move(cpu)).RunWithConfig(); +} + +TEST(LayerNormTest, LayerNorm_Scale_Scalar_NoBias_Axis2) { + OpTester test("LayerNormalization", 17); + test.AddAttribute("epsilon", 1e-05f); + test.AddAttribute("axis", 2); + std::vector x(2 * 2 * 2); + for (int i = 0; i < static_cast(x.size()); ++i) { + x[static_cast(i)] = static_cast(i); + } + test.AddInput("X", {2, 2, 2}, x); + + test.AddInput("Scale", {}, {1.5f}, true); + test.AddOutput("Y", {2, 2, 2}, + { + -1.5f, + 1.5f, + -1.5f, + 1.5f, + -1.5f, + 1.5f, + -1.5f, + 1.5f, + }); + + test.SetOutputAbsErr("Y", 1e-4f); + auto cpu = DefaultCpuExecutionProvider(); + if (!cpu) GTEST_SKIP() << "CPU EP not available in this build."; + test.ConfigEp(std::move(cpu)).RunWithConfig(); +} + +TEST(LayerNormTest, LayerNorm_Scale_Bias_Scalar_Axis2) { + OpTester test("LayerNormalization", 17); + test.AddAttribute("epsilon", 1e-05f); + test.AddAttribute("axis", 2); + std::vector x(2 * 2 * 2); + for (int i = 0; i < static_cast(x.size()); ++i) { + x[static_cast(i)] = static_cast(i); + } + test.AddInput("X", {2, 2, 2}, x); + + test.AddInput("Scale", {}, {1.5f}, true); + + test.AddInput("B", {}, {0.1f}, true); + test.AddOutput("Y", {2, 2, 2}, + { + -1.4f, + 1.6f, + -1.4f, + 1.6f, + -1.4f, + 1.6f, + -1.4f, + 1.6f, + }); + + test.SetOutputAbsErr("Y", 1e-4f); + + auto cpu = DefaultCpuExecutionProvider(); + if (!cpu) GTEST_SKIP() << "CPU EP not available in this build."; + test.ConfigEp(std::move(cpu)).RunWithConfig(); +} + +TEST(LayerNormTest, LayerNorm_Scale_Bias_Axis2) { + OpTester test("LayerNormalization", 17); + test.AddAttribute("epsilon", 1e-05f); + test.AddAttribute("axis", 2); + std::vector x(2 * 2 * 2); + for (int i = 0; i < static_cast(x.size()); ++i) { + x[static_cast(i)] = static_cast(i); + } + test.AddInput("X", {2, 2, 2}, x); + test.AddInput("Scale", {2}, {1.0f, 2.0f}, true); + test.AddInput("B", {2}, {0.0f, 0.5f}, true); + test.AddOutput("Y", {2, 2, 2}, + { + -1.0f, + 2.5f, + -1.0f, + 2.5f, + -1.0f, + 2.5f, + -1.0f, + 2.5f, + }); + + test.SetOutputAbsErr("Y", 1e-4f); + auto cpu = DefaultCpuExecutionProvider(); + if (!cpu) GTEST_SKIP() << "CPU EP not available in this build."; + test.ConfigEp(std::move(cpu)).RunWithConfig(); +} + +TEST(LayerNormTest, LayerNorm_Scale_Bias_4D_OuterInnerBroadcast_Axis3) { + OpTester test("LayerNormalization", 17); + test.AddAttribute("epsilon", 1e-05f); + test.AddAttribute("axis", 3); + std::vector x(1 * 2 * 2 * 2); + for (int i = 0; i < static_cast(x.size()); ++i) { + x[static_cast(i)] = static_cast(i); + } + test.AddInput("X", {1, 2, 2, 2}, x); + test.AddInput("Scale", {1, 2, 1, 2}, + { + 1.0f, + 1.1f, + 1.2f, + 1.3f, + }, + true); + test.AddInput("B", {1, 2, 1, 2}, + { + 0.0f, + 0.5f, + 1.0f, + 1.5f, + }, + true); + test.AddOutput("Y", {1, 2, 2, 2}, + { + -1.0f, + 1.6f, + -1.0f, + 1.6f, + + -0.2f, + 2.8f, + -0.2f, + 2.8f, + }); + + test.SetOutputAbsErr("Y", 1e-4f); + auto cpu = DefaultCpuExecutionProvider(); + if (!cpu) GTEST_SKIP() << "CPU EP not available in this build."; + test.ConfigEp(std::move(cpu)).RunWithConfig(); +} + +TEST(LayerNormTest, LayerNorm_Scale_Scalar_NoBias) { + OpTester test("LayerNormalization", 17); + test.AddAttribute("epsilon", 1e-05f); + test.AddAttribute("axis", 2); + std::vector x(2 * 2 * 2); + for (int i = 0; i < static_cast(x.size()); ++i) { + x[static_cast(i)] = static_cast(i); + } + test.AddInput("X", {2, 2, 2}, x); + test.AddInput("Scale", {}, {1.5f}, true); + 
test.AddOutput("Y", {2, 2, 2}, + { + -1.5f, + 1.5f, + -1.5f, + 1.5f, + -1.5f, + 1.5f, + -1.5f, + 1.5f, + }); + + test.SetOutputAbsErr("Y", 1e-4f); + auto cpu = DefaultCpuExecutionProvider(); + if (!cpu) GTEST_SKIP() << "CPU EP not available in this build."; + test.ConfigEp(std::move(cpu)).RunWithConfig(); +} +TEST(LayerNormTest, LayerNorm_Scale_Bias_Scalar) { + OpTester test("LayerNormalization", 17); + test.AddAttribute("epsilon", 1e-05f); + test.AddAttribute("axis", 2); + std::vector x(2 * 2 * 2); + for (int i = 0; i < static_cast(x.size()); ++i) { + x[static_cast(i)] = static_cast(i); + } + test.AddInput("X", {2, 2, 2}, x); + test.AddInput("Scale", {}, {1.5f}, true); + test.AddInput("B", {}, {0.1f}, true); + + test.AddOutput("Y", {2, 2, 2}, + { + -1.4f, + 1.6f, + -1.4f, + 1.6f, + -1.4f, + 1.6f, + -1.4f, + 1.6f, + }); + + test.SetOutputAbsErr("Y", 1e-4f); + + auto cpu = DefaultCpuExecutionProvider(); + if (!cpu) GTEST_SKIP() << "CPU EP not available in this build."; + test.ConfigEp(std::move(cpu)).RunWithConfig(); +} + +TEST(LayerNormTest, LayerNorm_Scale_Bias_PerLastDim) { + OpTester test("LayerNormalization", 17); + test.AddAttribute("epsilon", 1e-05f); + test.AddAttribute("axis", 2); - // CPU and CUDA EPs have check for unexpected scale or bias sizes. Exclude other EPs with a LayerNormalization - // implementation for which we don't control the check or error message. - RunTestOnCpuAndCuda(test, kLayerNormInputShapeMismatchError); + std::vector x(2 * 2 * 2); + for (int i = 0; i < static_cast(x.size()); ++i) { + x[static_cast(i)] = static_cast(i); + } + test.AddInput("X", {2, 2, 2}, x); + + test.AddInput("Scale", {2}, {1.0f, 2.0f}, true); + + test.AddInput("B", {2}, {0.0f, 0.5f}, true); + + test.AddOutput("Y", {2, 2, 2}, + { + -1.0f, + 2.5f, + -1.0f, + 2.5f, + -1.0f, + 2.5f, + -1.0f, + 2.5f, + }); + + test.SetOutputAbsErr("Y", 1e-4f); + + auto cpu = DefaultCpuExecutionProvider(); + if (!cpu) GTEST_SKIP() << "CPU EP not available in this build."; + test.ConfigEp(std::move(cpu)).RunWithConfig(); +} + +TEST(LayerNormTest, LayerNorm_Scale_Bias_4D_OuterInnerBroadcast) { + OpTester test("LayerNormalization", 17); + test.AddAttribute("epsilon", 1e-05f); + test.AddAttribute("axis", 3); + + std::vector x(1 * 2 * 2 * 2); + for (int i = 0; i < static_cast(x.size()); ++i) { + x[static_cast(i)] = static_cast(i); + } + test.AddInput("X", {1, 2, 2, 2}, x); + + test.AddInput("Scale", {1, 2, 1, 2}, + { + 1.0f, + 1.1f, + 1.2f, + 1.3f, + }, + true); + + test.AddInput("B", {1, 2, 1, 2}, + { + 0.0f, + 0.5f, + 1.0f, + 1.5f, + }, + true); + test.AddOutput("Y", {1, 2, 2, 2}, + { + -1.0f, + 1.6f, + -1.0f, + 1.6f, + + -0.2f, + 2.8f, + -0.2f, + 2.8f, + }); + + test.SetOutputAbsErr("Y", 1e-4f); + auto cpu = DefaultCpuExecutionProvider(); + if (!cpu) GTEST_SKIP() << "CPU EP not available in this build."; + test.ConfigEp(std::move(cpu)).RunWithConfig(); +} +TEST(LayerNormTest, LayerNorm_NormSize1_NoBias) { + OpTester test("LayerNormalization", 17); + test.AddAttribute("axis", 2); + test.AddAttribute("epsilon", 1e-5f); + + std::vector x = { + 1.0f, 2.0f, 3.0f, + 4.0f, 5.0f, 6.0f}; + test.AddInput("X", {2, 3, 1}, x); + test.AddInput("Scale", {1}, {1.0f}); + std::vector expected = { + 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f}; + + test.AddOutput("Y", {2, 3, 1}, expected); + test.AddOutput("Mean", {2, 3, 1}, + {1.0f, 2.0f, 3.0f, + 4.0f, 5.0f, 6.0f}); + + float inv_std = 1.0f / sqrtf(1e-5f); + test.AddOutput("InvStdDev", {2, 3, 1}, + {inv_std, inv_std, inv_std, + inv_std, inv_std, inv_std}); + + auto cpu = 
DefaultCpuExecutionProvider(); + if (!cpu) GTEST_SKIP() << "CPU EP not available in this build."; + test.ConfigEp(std::move(cpu)).RunWithConfig(); +} +TEST(LayerNormTest, LayerNorm_NormSize1_WithBiasScale) { + OpTester test("LayerNormalization", 17); + test.AddAttribute("axis", 2); + test.AddAttribute("epsilon", 1e-5f); + test.AddInput("X", {1, 2, 1}, {10.0f, 20.0f}); + test.AddInput("Scale", {1}, {2.0f}); + test.AddInput("Bias", {1}, {5.0f}); + test.AddOutput("Y", {1, 2, 1}, {5.0f, 5.0f}); + test.AddOutput("Mean", {1, 2, 1}, {10.0f, 20.0f}); + float inv_std = 1.0f / sqrtf(1e-5f); + test.AddOutput("InvStdDev", {1, 2, 1}, {inv_std, inv_std}); + auto cpu = DefaultCpuExecutionProvider(); + if (!cpu) GTEST_SKIP() << "CPU EP not available in this build."; + test.ConfigEp(std::move(cpu)).RunWithConfig(); +} + +TEST(LayerNormTest, LayerNorm_Scale_Broadcast_Inner_Mixed) { + OpTester test("LayerNormalization", 17); + test.AddAttribute("epsilon", 1e-05f); + test.AddAttribute("axis", 1); + std::vector dims{1, 2, 4}; + std::vector x = { + 0.0f, 1.0f, 2.0f, 3.0f, + 4.0f, 5.0f, 6.0f, 7.0f}; + test.AddInput("X", dims, x); + std::vector scale = {1.0f, 0.5f, 1.0f, 0.5f}; + test.AddInput("Scale", {1, 4}, scale); + std::vector expected_y = { + -1.527524f, -0.545544f, -0.654653f, -0.109109f, + 0.218218f, 0.327327f, 1.091088f, 0.763762f}; + test.AddOutput("Y", dims, expected_y); + auto cpu = DefaultCpuExecutionProvider(); + if (!cpu) GTEST_SKIP() << "CPU EP not available in this build."; + test.ConfigEp(std::move(cpu)).RunWithConfig(); } #if defined(USE_DNNL) @@ -478,6 +819,7 @@ TEST(LayerNormTest, LayerNorm17_Scale_Bias_bfloat16) { test.AddOutput("output", dims, MakeBFloat16({-0.0516f, -5.5776f, -0.0518f, -5.5788f, -0.0518f, -5.5788f})); test.Run(); } + #endif // USE_DNNL } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/nn/rms_norm_op_test.cc b/onnxruntime/test/providers/cpu/nn/rms_norm_op_test.cc index 4bdd3ea5adaff..d16e5eee3b50d 100644 --- a/onnxruntime/test/providers/cpu/nn/rms_norm_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/rms_norm_op_test.cc @@ -68,5 +68,652 @@ TEST(RMSNormalizationOpTest, RMSNorm_Scale_Float16) { kNnapiExecutionProvider, kQnnExecutionProvider}); } +TEST(RMSNormalizationOpTest, RMSNorm_Scale_Scalar_Axis2) { + OpTester test("RMSNormalization", 23); + test.AddAttribute("epsilon", 1e-05f); + test.AddAttribute("axis", 2); + std::vector x(2 * 5 * 3); + for (int i = 0; i < static_cast(x.size()); ++i) x[i] = static_cast(i); + test.AddInput("X", {2, 5, 3}, x); + test.AddInput("scale", {}, {1.5f}, true); + test.AddOutput( + "Y", {2, 5, 3}, + {0.0000f, 1.1619f, 2.3238f, + 1.1023f, 1.4697f, 1.8371f, + 1.2771f, 1.4899f, 1.7027f, + 1.3455f, 1.4950f, 1.6445f, + 1.3819f, 1.4971f, 1.6122f, + + 1.4044f, 1.4981f, 1.5917f, + 1.4197f, 1.4986f, 1.5775f, + 1.4308f, 1.4990f, 1.5671f, + 1.4392f, 1.4992f, 1.5592f, + 1.4458f, 1.4994f, 1.5529f}); + auto cpu = DefaultCpuExecutionProvider(); + if (!cpu) GTEST_SKIP() << "CPU EP not available in this build."; + test.ConfigEp(std::move(cpu)).RunWithConfig(); +} + +TEST(RMSNormalizationOpTest, RMSNorm_Scale_1x1x1_Axis2) { + OpTester test("RMSNormalization", 23); + test.AddAttribute("epsilon", 1e-05f); + test.AddAttribute("axis", 2); + std::vector x(2 * 2 * 2); + for (int i = 0; i < static_cast(x.size()); ++i) { + x[i] = static_cast(i); + } + test.AddInput("X", {2, 2, 2}, x); + + test.AddInput("scale", {1, 1, 1}, {1.0f}, true); + + test.AddOutput("Y", {2, 2, 2}, + { + 0.0000f, + 1.4142f, + 0.7845f, + 1.1767f, + + 
0.8835f, + 1.1043f, + 0.9204f, + 1.0738f, + }); + + test.SetOutputAbsErr("Y", 1e-4f); + + auto cpu = DefaultCpuExecutionProvider(); + if (!cpu) GTEST_SKIP() << "CPU EP not available in this build."; + test.ConfigEp(std::move(cpu)).RunWithConfig(); +} + +TEST(RMSNormalizationOpTest, RMSNorm_Scale_Vector3_Axis2) { + OpTester test("RMSNormalization", 23); + test.AddAttribute("epsilon", 1e-05f); + test.AddAttribute("axis", 2); + + std::vector x(2 * 5 * 3); + for (int i = 0; i < static_cast(x.size()); ++i) x[i] = static_cast(i); + test.AddInput("X", {2, 5, 3}, x); + + test.AddInput("scale", {3}, {1.5f, 1.5f, 1.5f}, true); + + test.AddOutput("Y", {2, 5, 3}, + {0.0000f, 1.1619f, 2.3238f, + 1.1023f, 1.4697f, 1.8371f, + 1.2771f, 1.4899f, 1.7027f, + 1.3455f, 1.4950f, 1.6445f, + 1.3819f, 1.4971f, 1.6122f, + + 1.4044f, 1.4981f, 1.5917f, + 1.4197f, 1.4986f, 1.5775f, + 1.4308f, 1.4990f, 1.5671f, + 1.4392f, 1.4992f, 1.5592f, + 1.4458f, 1.4994f, 1.5529f}); + + auto cpu = DefaultCpuExecutionProvider(); + if (!cpu) GTEST_SKIP() << "CPU EP not available in this build."; + test.ConfigEp(std::move(cpu)).RunWithConfig(); +} + +TEST(RMSNormalizationOpTest, RMSNorm_Scale_1x1x3_Axis2) { + OpTester test("RMSNormalization", 23); + test.AddAttribute("epsilon", 1e-05f); + test.AddAttribute("axis", 2); + + std::vector x(2 * 5 * 3); + for (int i = 0; i < static_cast(x.size()); ++i) x[i] = static_cast(i); + test.AddInput("X", {2, 5, 3}, x); + + test.AddInput("scale", {1, 1, 3}, {1.5f, 1.5f, 1.5f}, true); + + test.AddOutput("Y", {2, 5, 3}, + {0.0000f, 1.1619f, 2.3238f, + 1.1023f, 1.4697f, 1.8371f, + 1.2771f, 1.4899f, 1.7027f, + 1.3455f, 1.4950f, 1.6445f, + 1.3819f, 1.4971f, 1.6122f, + + 1.4044f, 1.4981f, 1.5917f, + 1.4197f, 1.4986f, 1.5775f, + 1.4308f, 1.4990f, 1.5671f, + 1.4392f, 1.4992f, 1.5592f, + 1.4458f, 1.4994f, 1.5529f}); + + auto cpu = DefaultCpuExecutionProvider(); + if (!cpu) GTEST_SKIP() << "CPU EP not available in this build."; + test.ConfigEp(std::move(cpu)).RunWithConfig(); +} + +TEST(RMSNormalizationOpTest, RMSNorm_Scale_Bx1x3_Axis2) { + OpTester test("RMSNormalization", 23); + test.AddAttribute("epsilon", 1e-5f); + test.AddAttribute("axis", 2); + + std::vector x(3 * 2 * 3); + for (int i = 0; i < static_cast(x.size()); ++i) { + x[i] = static_cast(i); + } + test.AddInput("X", {3, 2, 3}, x); + + test.AddInput( + "scale", {3, 1, 3}, + { + 1.0f, + 1.0f, + 1.0f, + 1.2f, + 1.2f, + 1.2f, + 1.4f, + 1.4f, + 1.4f, + }, + true); + + test.AddOutput( + "Y", {3, 2, 3}, + { + 0.0000f, + 0.7746f, + 1.5492f, + 0.7348f, + 0.9798f, + 1.2247f, + + 1.0216f, + 1.1919f, + 1.3622f, + 1.0764f, + 1.1960f, + 1.3156f, + + 1.2898f, + 1.3972f, + 1.5047f, + 1.3108f, + 1.3982f, + 1.4856f, + }); + + auto cpu = DefaultCpuExecutionProvider(); + if (!cpu) GTEST_SKIP() << "CPU EP not available in this build."; + test.ConfigEp(std::move(cpu)).RunWithConfig(); +} + +TEST(RMSNormalizationOpTest, RMSNorm_Scale_1xSx3_Axis2) { + OpTester test("RMSNormalization", 23); + test.AddAttribute("epsilon", 1e-5f); + test.AddAttribute("axis", 2); + + std::vector x(2 * 4 * 3); + for (int i = 0; i < static_cast(x.size()); ++i) { + x[i] = static_cast(i); + } + test.AddInput("X", {2, 4, 3}, x); + + test.AddInput("scale", + {1, 4, 3}, + { + 1.1f, + 1.1f, + 1.1f, + 1.2f, + 1.2f, + 1.2f, + 1.3f, + 1.3f, + 1.3f, + 1.4f, + 1.4f, + 1.4f, + }, + true); + + test.AddOutput( + "Y", {2, 4, 3}, + { + 0.0000f, + 0.8521f, + 1.7041f, + 0.8818f, + 1.1758f, + 1.4697f, + 1.1068f, + 1.2912f, + 1.4757f, + 1.2558f, + 1.3954f, + 1.5349f, + + 1.0134f, + 1.0978f, + 1.1823f, + 1.1235f, + 
1.1984f, + 1.2733f, + 1.2304f, + 1.2988f, + 1.3672f, + 1.3354f, + 1.3990f, + 1.4626f, + }); + + auto cpu = DefaultCpuExecutionProvider(); + if (!cpu) GTEST_SKIP() << "CPU EP not available in this build."; + test.ConfigEp(std::move(cpu)).RunWithConfig(); +} + +TEST(RMSNormalizationOpTest, RMSNorm_Scale_NoBroadcast_BxSx3_Axis2) { + OpTester test("RMSNormalization", 23); + test.AddAttribute("epsilon", 1e-05f); + test.AddAttribute("axis", 2); + + std::vector x(2 * 5 * 3); + for (int i = 0; i < static_cast(x.size()); ++i) x[i] = static_cast(i); + test.AddInput("X", {2, 5, 3}, x); + + std::vector scale(2 * 5 * 3, 1.5f); + test.AddInput("scale", {2, 5, 3}, scale, true); + + test.AddOutput("Y", {2, 5, 3}, + {0.0000f, 1.1619f, 2.3238f, + 1.1023f, 1.4697f, 1.8371f, + 1.2771f, 1.4899f, 1.7027f, + 1.3455f, 1.4950f, 1.6445f, + 1.3819f, 1.4971f, 1.6122f, + + 1.4044f, 1.4981f, 1.5917f, + 1.4197f, 1.4986f, 1.5775f, + 1.4308f, 1.4990f, 1.5671f, + 1.4392f, 1.4992f, 1.5592f, + 1.4458f, 1.4994f, 1.5529f}); + + auto cpu = DefaultCpuExecutionProvider(); + if (!cpu) GTEST_SKIP() << "CPU EP not available in this build."; + test.ConfigEp(std::move(cpu)).RunWithConfig(); +} + +TEST(RMSNormalizationOpTest, RMSNorm_Scale_1xCx1x1_Axis1) { + OpTester test("RMSNormalization", 23); + test.AddAttribute("epsilon", 1e-05f); + test.AddAttribute("axis", 1); + std::vector x(1 * 4 * 2 * 2); + for (int i = 0; i < static_cast(x.size()); ++i) x[i] = static_cast(i); + test.AddInput("X", {1, 4, 2, 2}, x); + + test.AddInput("scale", {1, 4, 1, 1}, + {1.1f, 1.2f, 1.3f, 1.4f}, + true); + + test.AddOutput("Y", {1, 4, 2, 2}, {0.0000000, 0.1249516, 0.2499032, 0.3748548, + + 0.5452434, 0.6815542, 0.8178651, 0.9541759, + + 1.1813605, 1.3290305, 1.4767007, 1.6243708, + + 1.9083518, 2.0673811, 2.2264102, 2.3854396}); + test.SetOutputAbsErr("Y", 1e-4f); + auto cpu = DefaultCpuExecutionProvider(); + if (!cpu) GTEST_SKIP() << "CPU EP not available in this build."; + test.ConfigEp(std::move(cpu)).RunWithConfig(); +} + +TEST(RMSNormalizationOpTest, RMSNorm_Scale_1xCx1_Axis1) { + OpTester test("RMSNormalization", 23); + test.AddAttribute("epsilon", 1e-05f); + test.AddAttribute("axis", 1); + std::vector x(2 * 3 * 2); + for (int i = 0; i < static_cast(x.size()); ++i) x[i] = static_cast(i); + test.AddInput("X", {2, 3, 2}, x); + test.AddInput("scale", {1, 3, 1}, {1.0f, 1.2f, 1.4f}, true); + test.AddOutput( + "Y", {2, 3, 2}, + {0.0f, 0.33028895f, + 0.79269350f, 1.18904030f, + 1.84961808f, 2.31202269f, + 0.69205177f, 0.80739373f, + 1.10728300f, 1.24569333f, + 1.61478746f, 1.77626622f}); + + auto cpu = DefaultCpuExecutionProvider(); + if (!cpu) GTEST_SKIP() << "CPU EP not available in this build."; + test.ConfigEp(std::move(cpu)).RunWithConfig(); +} + +TEST(RMSNormalizationOpTest, RMSNorm_Scale_1x3x2x1_Axis1) { + OpTester test("RMSNormalization", 23); + test.AddAttribute("epsilon", 1e-05f); + test.AddAttribute("axis", 1); + std::vector x(1 * 3 * 2 * 2); + for (int i = 0; i < static_cast(x.size()); ++i) + x[i] = static_cast(i); + test.AddInput("X", {1, 3, 2, 2}, x); + test.AddInput( + "scale", {1, 3, 2, 1}, + { + 1.0f, + 1.1f, + 1.2f, + 1.3f, + 1.4f, + 1.5f, + }, + true); + test.AddOutput( + "Y", {1, 3, 2, 2}, + {0.0f, 0.15399808f, + 0.33879578f, 0.50819367f, + 0.73919082f, 0.92398852f, + 1.20118499f, 1.40138257f, + 1.72477841f, 1.94037580f, + 2.30997109f, 2.54096842f}); + + auto cpu = DefaultCpuExecutionProvider(); + if (!cpu) GTEST_SKIP() << "CPU EP not available in this build."; + test.ConfigEp(std::move(cpu)).RunWithConfig(); +} + 
+TEST(RMSNormalizationOpTest, RMSNorm_Scale_1xSx1xW_Axis2) { + OpTester test("RMSNormalization", 23); + test.AddAttribute("epsilon", 1e-05f); + test.AddAttribute("axis", 2); + std::vector x(1 * 2 * 2 * 2); + for (int i = 0; i < static_cast(x.size()); ++i) x[i] = static_cast(i); + test.AddInput("X", {1, 2, 2, 2}, x); + + test.AddInput("scale", {1, 2, 1, 2}, + {1.0f, 1.2f, + 1.4f, 1.6f}, + true); + test.AddOutput("Y", {1, 2, 2, 2}, + {0.0000f, 0.6414f, + 1.0690f, 1.9243f, + + 0.9978f, 1.4254f, + 1.4967f, 1.9956f}); + + auto cpu = DefaultCpuExecutionProvider(); + if (!cpu) GTEST_SKIP() << "CPU EP not available in this build."; + test.ConfigEp(std::move(cpu)).RunWithConfig(); +} + +TEST(RMSNormalizationOpTest, RMSNorm_Scale_1x1xHx1_Axis2) { + OpTester test("RMSNormalization", 23); + test.AddAttribute("epsilon", 1e-05f); + test.AddAttribute("axis", 2); + + std::vector x(1 * 2 * 2 * 2); + for (int i = 0; i < static_cast(x.size()); ++i) x[i] = static_cast(i); + test.AddInput("X", {1, 2, 2, 2}, x); + + test.AddInput("scale", {1, 1, 2, 1}, {1.0f, 1.3f}, true); + + test.AddOutput("Y", {1, 2, 2, 2}, + {0.0000f, 0.5345f, + 1.3898f, 2.0846f, + + 0.7127f, 0.8909f, + 1.3898f, 1.6214f}); + + auto cpu = DefaultCpuExecutionProvider(); + if (!cpu) GTEST_SKIP() << "CPU EP not available in this build."; + test.ConfigEp(std::move(cpu)).RunWithConfig(); +} + +TEST(RMSNormalizationOpTest, RMSNorm_Scale_1x1x1xW_Axis2) { + OpTester test("RMSNormalization", 23); + test.AddAttribute("epsilon", 1e-05f); + test.AddAttribute("axis", 2); + + std::vector x(1 * 2 * 2 * 3); + for (int i = 0; i < (int)x.size(); ++i) x[i] = (float)i; + test.AddInput("X", {1, 2, 2, 3}, x); + + test.AddInput("scale", {1, 1, 1, 3}, {1.0f, 1.2f, 1.4f}, true); + + test.AddOutput("Y", {1, 2, 2, 3}, + {0.0000f, 0.3963f, 0.9248f, + 0.9909f, 1.5854f, 2.3120f, + 0.6921f, 0.9689f, 1.2918f, + 1.0381f, 1.3841f, 1.7763f}); + auto cpu = DefaultCpuExecutionProvider(); + if (!cpu) GTEST_SKIP() << "CPU EP not available in this build."; + test.ConfigEp(std::move(cpu)).RunWithConfig(); +} + +TEST(RMSNormalizationOpTest, RMSNorm_Scale_1xSx1x1_Axis2) { + OpTester test("RMSNormalization", 23); + test.AddAttribute("epsilon", 1e-05f); + test.AddAttribute("axis", 2); + + std::vector x(1 * 3 * 2 * 2); + for (int i = 0; i < (int)x.size(); ++i) x[i] = (float)i; + test.AddInput("X", {1, 3, 2, 2}, x); + + test.AddInput("scale", {1, 3, 1, 1}, {1.0f, 1.2f, 1.4f}, true); + + test.AddOutput("Y", {1, 3, 2, 2}, + {0.0000f, 0.5345f, 1.0690f, 1.6036f, + 0.8552f, 1.0690f, 1.2829f, 1.4967f, + 1.1709f, 1.3172f, 1.4636f, 1.6099f}); + auto cpu = DefaultCpuExecutionProvider(); + if (!cpu) GTEST_SKIP() << "CPU EP not available in this build."; + test.ConfigEp(std::move(cpu)).RunWithConfig(); +} + +TEST(RMSNormalizationOpTest, RMSNorm_Scale_Bx1x1xW_Axis2) { + OpTester test("RMSNormalization", 23); + test.AddAttribute("epsilon", 1e-05f); + test.AddAttribute("axis", 2); + + std::vector x(2 * 1 * 2 * 2); + for (int i = 0; i < (int)x.size(); ++i) x[i] = (float)i; + test.AddInput("X", {2, 1, 2, 2}, x); + + test.AddInput("scale", {2, 1, 1, 2}, {1.0f, 1.1f, 1.3f, 1.4f}, true); + + test.AddOutput("Y", {2, 1, 2, 2}, + {0.0000f, 0.5880f, 1.0690f, 1.7639f, + 0.9265f, 1.2472f, 1.3898f, 1.7461f}); + auto cpu = DefaultCpuExecutionProvider(); + if (!cpu) GTEST_SKIP() << "CPU EP not available in this build."; + test.ConfigEp(std::move(cpu)).RunWithConfig(); +} + +TEST(RMSNormalizationOpTest, RMSNorm_Scale_1x1xHxW_Axis2) { + OpTester test("RMSNormalization", 23); + test.AddAttribute("epsilon", 1e-05f); + 
test.AddAttribute("axis", 2); + + test.AddInput("X", {1, 2, 2, 3}, + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}); + + test.AddInput("scale", {1, 1, 2, 3}, + {1.0f, 1.1f, 1.2f, + 1.3f, 1.4f, 1.5f}, + true); + + test.AddOutput("Y", {1, 2, 2, 3}, + {0.0000f, 0.3633f, 0.7927f, 1.2881f, 1.8496f, 2.4772f, + 0.6921f, 0.8881f, 1.1073f, 1.3495f, 1.6148f, 1.9031f}); + auto cpu = DefaultCpuExecutionProvider(); + if (!cpu) GTEST_SKIP() << "CPU EP not available in this build."; + test.ConfigEp(std::move(cpu)).RunWithConfig(); +} + +TEST(RMSNormalizationOpTest, RMSNorm_Scale_1xSx1xW_AxisNeg2) { + OpTester test("RMSNormalization", 23); + test.AddAttribute("epsilon", 1e-05f); + test.AddAttribute("axis", -2); + + std::vector x(1 * 2 * 2 * 2); + for (int i = 0; i < (int)x.size(); ++i) x[i] = (float)i; + test.AddInput("X", {1, 2, 2, 2}, x); + + test.AddInput("scale", {1, 2, 1, 2}, + {1.0f, 1.2f, 1.4f, 1.6f}, true); + + test.AddOutput("Y", {1, 2, 2, 2}, + {0.0000f, 0.6414f, + 1.0690f, 1.9243f, + + 0.9978f, 1.4254f, + 1.4967f, 1.9956f}); + auto cpu = DefaultCpuExecutionProvider(); + if (!cpu) GTEST_SKIP() << "CPU EP not available in this build."; + test.ConfigEp(std::move(cpu)).RunWithConfig(); +} + +TEST(RMSNormalizationOpTest, RMSNorm_Scale_1xSx1x1xC_Axis3) { + OpTester test("RMSNormalization", 23); + test.AddAttribute("epsilon", 1e-05f); + test.AddAttribute("axis", 3); + + const int B = 1, S = 2, H = 2, W = 2, C = 3; + std::vector x(B * S * H * W * C); + for (int i = 0; i < (int)x.size(); ++i) x[i] = (float)i; + test.AddInput("X", {B, S, H, W, C}, x); + + test.AddInput("scale", {1, S, 1, 1, C}, + {1.0f, 1.1f, 1.2f, + 1.3f, 1.4f, 1.5f}, + true); + + test.AddOutput("Y", {B, S, H, W, C}, + { + 0.0000f, + 0.3633f, + 0.7927f, + 0.9909f, + 1.4533f, + 1.9817f, + 0.6921f, + 0.8881f, + 1.1073f, + 1.0381f, + 1.2688f, + 1.5225f, + 1.0685f, + 1.2466f, + 1.4383f, + 1.3356f, + 1.5342f, + 1.7465f, + 1.1375f, + 1.2931f, + 1.4584f, + 1.3271f, + 1.4973f, + 1.6771f, + }); + auto cpu = DefaultCpuExecutionProvider(); + if (!cpu) GTEST_SKIP() << "CPU EP not available in this build."; + test.ConfigEp(std::move(cpu)).RunWithConfig(); +} + +TEST(RMSNormalizationOpTest, RMSNorm_Scale_Float16_OuterInnerBroadcast_Axis1) { + OpTester test("RMSNormalization", 23); + test.AddAttribute("epsilon", 1e-05f); + test.AddAttribute("axis", 1); + std::vector x_f(24); + for (int i = 0; i < 24; ++i) x_f[i] = static_cast(i); + + std::vector x_half(x_f.size()); + for (size_t i = 0; i < x_f.size(); ++i) + x_half[i] = MLFloat16(x_f[i]); + + test.AddInput("X", {2, 3, 4}, x_half); + + std::vector scale_f = {1.0f, 2.0f, 3.0f}; + std::vector scale_half(scale_f.size()); + for (size_t i = 0; i < scale_f.size(); ++i) + scale_half[i] = MLFloat16(scale_f[i]); + + test.AddInput("scale", {1, 3, 1}, scale_half, true); + + std::vector y_f = { + 0.0000f, 0.1540f, 0.3080f, 0.4620f, + 1.2320f, 1.5400f, 1.8480f, 2.1560f, + 3.6960f, 4.1579f, 4.6199f, 5.0819f, + 0.6728f, 0.7288f, 0.7849f, 0.8409f, + 1.7940f, 1.9061f, 2.0183f, 2.1304f, + 3.3638f, 3.5319f, 3.7001f, 3.8683f}; + + std::vector y_half(y_f.size()); + for (size_t i = 0; i < y_f.size(); ++i) + y_half[i] = MLFloat16(y_f[i]); + + test.AddOutput("Y", {2, 3, 4}, y_half); + auto cpu = DefaultCpuExecutionProvider(); + if (!cpu) GTEST_SKIP() << "CPU EP not available in this build."; + test.ConfigEp(std::move(cpu)).RunWithConfig(); +} +TEST(RMSNormalizationOpTest, RMSNorm_Scale_Float16_OuterBroadcast_BxSx1_Axis2) { + OpTester test("RMSNormalization", 23); + test.AddAttribute("epsilon", 1e-05f); + test.AddAttribute("axis", 
2); + std::vector x_f(2 * 2 * 3); + for (int i = 0; i < static_cast(x_f.size()); ++i) { + x_f[static_cast(i)] = static_cast(i); + } + std::vector x_half(x_f.size()); + for (size_t i = 0; i < x_f.size(); ++i) { + x_half[i] = MLFloat16(x_f[i]); + } + test.AddInput("X", {2, 2, 3}, x_half); + std::vector scale_f = { + 1.0f, 2.0f, + 3.0f, 4.0f}; + std::vector scale_half(scale_f.size()); + for (size_t i = 0; i < scale_f.size(); ++i) { + scale_half[i] = MLFloat16(scale_f[i]); + } + test.AddInput("scale", {2, 2, 1}, scale_half, true); + std::vector y_f = { + 0.0000f, 0.7746f, 1.5492f, + 1.4697f, 1.9596f, 2.4495f, + + 2.5541f, 2.9798f, 3.4055f, + 3.5881f, 3.9867f, 4.3854f}; + + std::vector y_half(y_f.size()); + for (size_t i = 0; i < y_f.size(); ++i) { + y_half[i] = MLFloat16(y_f[i]); + } + + test.AddOutput("Y", {2, 2, 3}, y_half); + auto cpu = DefaultCpuExecutionProvider(); + if (!cpu) GTEST_SKIP() << "CPU EP not available in this build."; + test.ConfigEp(std::move(cpu)).RunWithConfig(); +} +TEST(RMSNormalizationOpTest, RMSNorm_Scale_Broadcast_Inner_Mixed) { + OpTester test("RMSNormalization", 23); + test.AddAttribute("epsilon", 1e-05f); + test.AddAttribute("axis", 1); + std::vector dims{1, 2, 4}; + std::vector x = { + 0.0f, 1.0f, 2.0f, 3.0f, + 4.0f, 5.0f, 6.0f, 7.0f}; + test.AddInput("X", dims, x); + std::vector scale = {1.0f, 0.5f, 1.0f, 0.5f}; + test.AddInput("Scale", {1, 4}, scale); + std::vector expected = { + 0.0f, + 0.119527f, + 0.478108f, + 0.358581f, + 0.956216f, + 0.597635f, + 1.434324f, + 0.836689f}; + + test.AddOutput("Y", dims, expected); + + auto cpu = DefaultCpuExecutionProvider(); + if (!cpu) GTEST_SKIP() << "CPU EP not available in this build."; + test.ConfigEp(std::move(cpu)).RunWithConfig(); +} + } // namespace test } // namespace onnxruntime
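
Reviewer note (not part of the patch): the generic broadcasting path added above hinges on one idea — right-align the scale/bias shape to X, then give every size-1 dimension a stride of 0 so the same element is reused along that axis. The standalone C++ sketch below illustrates that trick under those assumptions; the helper names (make_strides, offset_for) are illustrative only and are not part of the onnxruntime API, though the logic mirrors LayerNormHelper::MakeStrides in the patch.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Right-aligned broadcast strides: a size-1 dimension gets stride 0, so the
// same element is reused along that axis (NumPy-style broadcasting).
std::vector<int64_t> make_strides(const std::vector<int64_t>& dims) {
  std::vector<int64_t> strides(dims.size(), 0);
  int64_t running = 1;
  for (size_t i = dims.size(); i-- > 0;) {
    strides[i] = (dims[i] == 1) ? 0 : running;
    running *= std::max<int64_t>(1, dims[i]);
  }
  return strides;
}

// Flat offset into the broadcast scale/bias tensor for a multi-index of X.
int64_t offset_for(const std::vector<int64_t>& idx,
                   const std::vector<int64_t>& strides) {
  int64_t off = 0;
  for (size_t d = 0; d < idx.size(); ++d) off += idx[d] * strides[d];
  return off;
}

int main() {
  // X shape (2, 3, 4); scale shape (1, 3, 1), already right-aligned to X.
  const std::vector<int64_t> scale_dims = {1, 3, 1};
  const auto strides = make_strides(scale_dims);  // -> {0, 1, 0}
  // Element X[1][2][3] reads scale element 2, independent of the outer and
  // innermost indices, because those dimensions broadcast with stride 0.
  std::cout << offset_for({1, 2, 3}, strides) << "\n";  // prints 2
  return 0;
}

A worked check against one of the new tests: for LayerNorm_Scale_Bias_4D_OuterInnerBroadcast (X shape {1, 2, 2, 2}, scale shape {1, 2, 1, 2}, axis 3), make_strides yields {0, 2, 0, 1}, so each length-2 row reads scale elements {0, 1} or {2, 3} depending on the second dimension — exactly the outer/inner stride split the generic path computes before the tight inner loop.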