161 changes: 149 additions & 12 deletions onnxruntime/core/providers/cpu/nn/layer_norm_helper.h
@@ -6,12 +6,12 @@
#include "core/framework/tensor_shape.h"
#include "core/common/status.h"
#include "core/common/narrow.h"
#include "core/common/inlined_containers.h"

namespace onnxruntime {

constexpr const char* kLayerNormInputShapeMismatchError =
    "Scale and (optional) bias must match X.shape[axis:] or be NumPy-broadcastable to it.";

constexpr const char* kLayerNormInvalidSize = "Size of X.shape[axis:] must be larger than 1, got ";

@@ -23,15 +23,31 @@ struct LayerNormParams {
  int64_t scale_size;
  int64_t bias_size;
  int64_t broadcast_param;
  bool use_generic_broadcast{false};  // true: full NumPy-style broadcast; false: legacy broadcast_param path
  onnxruntime::InlinedVector<int64_t, 8> x_dims;
  onnxruntime::InlinedVector<int64_t, 8> x_inner_dims;  // X.shape[axis:]
  onnxruntime::InlinedVector<int64_t, 8> sc_dims;       // scale shape, right-aligned to X's rank
  onnxruntime::InlinedVector<int64_t, 8> bi_dims;       // bias shape, right-aligned to X's rank
  onnxruntime::InlinedVector<int64_t, 8> sc_strides;
  onnxruntime::InlinedVector<int64_t, 8> bi_strides;
  int64_t axis{0};
  int64_t last_rank{0};
  onnxruntime::InlinedVector<int64_t, 8> sc_inner_inc;  // scale strides for inner dims [axis..]
  onnxruntime::InlinedVector<int64_t, 8> bi_inner_inc;  // bias strides for inner dims [axis..]
  onnxruntime::InlinedVector<int64_t, 8> sc_outer_inc;  // how far the scale pointer moves when an outer-dimension index of X (dims 0..axis-1) changes
  onnxruntime::InlinedVector<int64_t, 8> bi_outer_inc;  // same, but for the bias pointer
  onnxruntime::InlinedVector<int64_t, 8> x_outer_strides;  // X strides for outer dims [0..axis-1]
};

// Fast-path broadcasting for axis = 2:
// When X shape is (B, S, ...), x_row (the index of one row in X) is in the range [0, B * S).
// We support the following scale/bias shapes in this path:
//   (1, 1, ...) or (...): broadcast_param is 0.
//   (B, 1, ...):          broadcast_param is S.
//   (B, S, ...):          broadcast_param is 1.
//   (1, S, ...):          broadcast_param is -S.
// For all other NumPy-broadcastable shapes we fall back to the generic
// broadcasting path (use_generic_broadcast = true) and ignore broadcast_param.
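//
// Illustrative example (not in the original diff; the LAYER_NORM_SCALE_BIAS_OFFSET
// macro body is collapsed in this view, so the arithmetic is reconstructed from
// the encoding above): X.shape = (B=2, S=3, 4), norm_size = 4, x_row in [0, 6).
//   scale.shape = (2, 1, 4) -> broadcast_param =  3 (= S): row 4 reads scale block 4 / 3 = 1.
//   scale.shape = (1, 3, 4) -> broadcast_param = -3:       row 4 reads scale block 4 % 3 = 1.
// Either way, the element offset into scale is block_index * norm_size.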

// Below is a macro to compute the offset for scale and bias data for a row of X.
#ifndef LAYER_NORM_SCALE_BIAS_OFFSET
@@ -48,30 +64,151 @@ class LayerNormHelper {
                            bool has_bias,
                            int64_t axis,
                            LayerNormParams& params) {
    // Initialize basic layout parameters: how many rows there are and how many
    // elements are normalized per row, as well as the total scale/bias sizes.
    params.num_rows = x_shape.SizeToDimension(onnxruntime::narrow<size_t>(axis));
    params.norm_size = x_shape.SizeFromDimension(onnxruntime::narrow<size_t>(axis));
    params.scale_size = scale_shape.Size();
    params.bias_size = has_bias ? bias_shape.Size() : 0;
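
    // Illustrative example (not in the original diff): for X.shape = (2, 3, 4)
    // and axis = 2, num_rows = 2 * 3 = 6 and norm_size = 4.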

    params.broadcast_param = 0;
    params.axis = axis;

    if (params.norm_size <= 1) {
      params.broadcast_param = kLayerNormInvalidInput;
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, kLayerNormInvalidSize, params.norm_size);
    } else if (params.scale_size != params.norm_size || (has_bias && params.bias_size != params.scale_size)) {
      // Try to encode simple (B, S, ...) layouts into broadcast_param so that the
      // fast path can be used. If this fails, broadcast_param is set to
      // kLayerNormInvalidInput and we fall back to generic broadcasting below.
      params.broadcast_param = GetBroadcastParam(x_shape, scale_shape, has_bias ? &bias_shape : nullptr, axis);
    }

    const size_t xr = x_shape.NumDimensions();
    const size_t sr = scale_shape.NumDimensions();
    const size_t br = has_bias ? bias_shape.NumDimensions() : 0;
    params.x_dims.clear();
    params.x_dims.reserve(xr);
    for (size_t i = 0; i < xr; ++i) params.x_dims.push_back(x_shape.GetDims()[i]);

    // Right-align the scale (and bias) shape to X's rank, filling leading
    // dimensions with 1 so that NumPy-style broadcasting rules can be applied.
    params.sc_dims.clear();
    params.sc_dims.resize(xr, 1);
    for (size_t i = 0; i < sr; ++i) {
      params.sc_dims[xr - 1 - i] = scale_shape.GetDims()[sr - 1 - i];
    }

    params.bi_dims.clear();
    if (has_bias) {
      params.bi_dims.resize(xr, 1);
      for (size_t i = 0; i < br; ++i) {
        params.bi_dims[xr - 1 - i] = bias_shape.GetDims()[br - 1 - i];
      }
    }
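
    // Illustrative example (not in the original diff): with X.shape = (2, 3, 4),
    // scale.shape = (4) right-aligns to sc_dims = {1, 1, 4} and
    // scale.shape = (3, 4) right-aligns to sc_dims = {1, 3, 4}.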

    // Validate that scale and bias shapes are NumPy-broadcastable to X.
    // If not, we fail early with a clear shape mismatch error.
    const bool sc_ok = IsNumpyBroadcastable(params.sc_dims, params.x_dims);
    const bool bi_ok = !has_bias || IsNumpyBroadcastable(params.bi_dims, params.x_dims);
    if (!sc_ok || !bi_ok) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                             kLayerNormInputShapeMismatchError,
                             " X.shape=", x_shape,
                             " scale.shape=", scale_shape,
                             " bias.shape=", bias_shape,
                             " and axis=", axis);
    }

    // Cache the inner dimensions X.shape[axis:] that are normalized together
    // for each logical row.
    params.last_rank = onnxruntime::narrow<int64_t>(xr) - axis;
    params.x_inner_dims.clear();
    params.x_inner_dims.reserve(params.last_rank > 0 ? static_cast<size_t>(params.last_rank) : 0);
    for (size_t i = static_cast<size_t>(axis); i < xr; ++i) {
      params.x_inner_dims.push_back(params.x_dims[i]);
    }

    params.sc_strides = MakeStrides(params.sc_dims);
    params.bi_strides.clear();
    if (has_bias) {
      params.bi_strides = MakeStrides(params.bi_dims);
    }

    // Precompute how scale/bias advance along the inner dimensions [axis..];
    // these increments are used inside the per-row normalization loop.
    params.sc_inner_inc.clear();
    params.bi_inner_inc.clear();
    for (size_t i = static_cast<size_t>(axis); i < xr; ++i) {
      params.sc_inner_inc.push_back(params.sc_strides[i]);
      if (has_bias) {
        params.bi_inner_inc.push_back(params.bi_strides[i]);
      }
    }
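
    // Illustrative example (not in the original diff): with axis = 2 and
    // sc_dims = {1, 3, 4}, MakeStrides gives sc_strides = {0, 4, 1}, so
    // sc_inner_inc = {1}; a broadcast dimension (size 1) gets increment 0 and
    // the same scale value is reused while X advances along it.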

    // Compute strides for X over the outer dimensions [0..axis-1],
    // used to locate the base address of each logical row in X.
    params.x_outer_strides.clear();
    params.x_outer_strides.resize(static_cast<size_t>(axis), 1);
    if (axis > 1) {
      for (int64_t d = axis - 2; d >= 0; --d) {
        const size_t du = static_cast<size_t>(d);
        params.x_outer_strides[du] = params.x_outer_strides[du + 1] * params.x_dims[du + 1];
      }
    }
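
    // Illustrative example (not in the original diff): X.shape = (2, 3, 4) with
    // axis = 2 yields x_outer_strides = {3, 1}, so the row with outer indices
    // (i0, i1) is logical row i0 * 3 + i1.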

    // Detect whether scale/bias depend on any outer dimensions [0..axis-1].
    // If any outer stride is non-zero, scale/bias are not purely "inner-only"
    // and the simple fast path based on broadcast_param is not sufficient.
    params.sc_outer_inc.clear();
    params.bi_outer_inc.clear();
    for (int64_t i = 0; i < axis; ++i) {
      params.sc_outer_inc.push_back(params.sc_strides[static_cast<size_t>(i)]);
      params.bi_outer_inc.push_back(has_bias ? params.bi_strides[static_cast<size_t>(i)] : 0);
    }

    bool outer_dep = false;
    for (int64_t i = 0; i < axis; ++i) {
      if (params.sc_outer_inc[static_cast<size_t>(i)] != 0 ||
          (has_bias && params.bi_outer_inc[static_cast<size_t>(i)] != 0)) {
        outer_dep = true;
        break;
      }
    }

    // Enable the generic NumPy-style broadcasting path if either:
    // - the fast path cannot represent this shape (broadcast_param is invalid), or
    // - scale/bias have any dependency on the outer dimensions.
    params.use_generic_broadcast = outer_dep || (params.broadcast_param == kLayerNormInvalidInput);
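
    // Illustrative example (not in the original diff): with axis = 2 and
    // X.shape = (2, 3, 4), scale.shape = (2, 1, 4) gives sc_strides = {4, 0, 1}
    // and sc_outer_inc = {4, 0}; the non-zero outer increment selects the
    // generic broadcasting path.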

    return Status::OK();
  }

 private:
  // Returns true when shape `a` (already right-aligned to the same rank as `b`)
  // can be broadcast to `b` under NumPy rules: every dim of `a` is 1 or equals b's.
  static bool IsNumpyBroadcastable(gsl::span<const int64_t> a,
                                   gsl::span<const int64_t> b) {
    ORT_ENFORCE(a.size() == b.size());
    for (size_t k = 0; k < a.size(); ++k) {
      const int64_t ak = a[k];
      const int64_t bk = b[k];
      if (!(ak == 1 || ak == bk)) {
        return false;
      }
    }
    return true;
  }
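
  // Illustrative example (not in the original diff):
  //   IsNumpyBroadcastable({1, 1, 4}, {2, 3, 4}) -> true
  //   IsNumpyBroadcastable({2, 2, 4}, {2, 3, 4}) -> false (2 is neither 1 nor 3)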

  // Builds right-to-left (row-major) strides for `dims`, assigning stride 0 to
  // broadcast dimensions (size 1) so that indexing reuses the same element.
  static InlinedVector<int64_t, 8> MakeStrides(const InlinedVector<int64_t, 8>& dims) {
    InlinedVector<int64_t, 8> strides(dims.size(), 0);
    if (dims.empty()) return strides;

    int64_t running = 1;
    for (ptrdiff_t i = dims.size() - 1; i >= 0; --i) {
      const size_t idx = static_cast<size_t>(i);
      strides[idx] = (dims[idx] == 1) ? 0 : running;
      running *= std::max<int64_t>(1, dims[idx]);
    }

    return strides;
  }
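
  // Illustrative example (not in the original diff):
  //   MakeStrides({2, 1, 4}) -> {4, 0, 1}; dim 1 is broadcast, so its stride is
  //   0 while the running product still advances past it.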

  static int64_t GetBroadcastParam(const TensorShape& x_shape,
                                   const TensorShape& scale_shape,
                                   const TensorShape* bias_shape,