Conversation
exla/lib/exla/defn.ex
Outdated
    Value.all_gather(
      [tensor],
      expr_to_typespec(ans),
      all_gather_dim,
      replica_groups,
      use_global_device_ids,
      Keyword.take(opts, [:channel_id])
    )
    |> hd()
Let's hard match for now instead of hd (i.e. [result] = Value.all_gather(...)).
And then add a comment that we might want to surface all_gather as an operation that takes a container of operands instead of a single one.
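A minimal sketch of that, reusing the variables from the quoted snippet (the surrounding function is elided):

```elixir
# Hard match on the single result instead of piping into hd/1.
# TODO: we might want to surface all_gather as an operation that takes a
# container of operands instead of a single one.
[result] =
  Value.all_gather(
    [tensor],
    expr_to_typespec(ans),
    all_gather_dim,
    replica_groups,
    use_global_device_ids,
    Keyword.take(opts, [:channel_id])
  )
```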
exla/lib/exla/mlir/value.ex
Outdated

    attributes =
      if opts[:channel_id] do
        attributes ++ [channel_id: attr_i64(opts[:channel_id])]
      else
        attributes
      end

Let's use Keyword.put instead of ++
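A sketch of that suggestion with the same variables as the quoted snippet:

```elixir
# Keyword.put/3 avoids building and appending a single-element list.
attributes =
  if opts[:channel_id] do
    Keyword.put(attributes, :channel_id, attr_i64(opts[:channel_id]))
  else
    attributes
  end
```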
exla/lib/exla/mlir/value.ex
Outdated

      end
    end

    def all_gather([%Value{function: func} | _] = operands, typespec, all_gather_dim, replica_groups, use_global_device_ids, opts \\ []) do
How about making channel_id a required argument and just passing the value directly?
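On the caller side in exla/lib/exla/defn.ex that would roughly look like the following sketch, passing the value instead of Keyword.take/2:

```elixir
[result] =
  Value.all_gather(
    [tensor],
    expr_to_typespec(ans),
    all_gather_dim,
    replica_groups,
    use_global_device_ids,
    opts[:channel_id]
  )
```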
nx/lib/nx/defn/evaluator.ex
Outdated
    if op == :all_gather and not function_exported?(mod, :all_gather, 3) do
      raise ArgumentError,
            "all_gather/3 is not supported by backend #{inspect(mod)}."
    end
If we remove this, do we have a test verifying this raise? Also, I believe this is already checked elsewhere.
If it's not checked elsewhere, it seems to me that this check should be more general.
    _all_gather_dim = opts[:all_gather_dim]
    replica_groups = opts[:replica_groups]

    # Calculate group size (number of replicas per group)
    _group_size =
      case replica_groups do
        [first_group | _] -> length(first_group)
        [] -> 1
      end

    # Calculate output shape by multiplying the gather dimension by group_size
    input_shape = tensor.shape
    output_shape =
      input_shape
      # |> Tuple.to_list()
      # |> List.update_at(all_gather_dim, &(&1 * group_size))
      # |> List.to_tuple()

    # Create output tensor with the new shape
There are a few unused values here due to the stray comments; those should all be removed. Also, just pass tensor as out directly.
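A sketch of the trimmed-down version, assuming the backend callback shape all_gather(out, tensor, opts) implied by the all_gather/3 arity check above:

```elixir
# Drop the unused bindings and the commented-out shape math; pass `tensor`
# itself as the out template.
mod.all_gather(tensor, tensor, opts)
```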
nx/lib/nx/defn/kernel.ex
Outdated
    * `tensor` - The input tensor to gather
    * `all_gather_dim` - The dimension along which to gather
    * `replica_groups` - 2D list defining how replicas are grouped (required)
I'm not sure this is the terminology we want to surface here. For now, let's make the function all_gather(tensor, opts) and defer the documentation of opts to the specific backend or compiler.
And in EXLA we should add a new section to the EXLA moduledoc describing sharding.
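A sketch of how that entry point could look; the doc wording and the delegation target are illustrative only:

```elixir
@doc """
Gathers a sharded tensor across partitions during distributed execution.

Options are compiler/backend specific; see the compiler or backend
documentation (for example EXLA) for the supported keys.
"""
def all_gather(tensor, opts \\ []) do
  # Illustrative delegation; the concrete implementation lives in
  # Nx.Defn.Expr in this PR.
  Nx.Defn.Expr.all_gather(tensor, opts)
end
```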
polvalente left a comment
This is looking great! I think we need more tests in both Nx and EXLA
Implements Nx.Defn.Kernel.all_gather/2 to gather sharded tensor data across mesh partitions during distributed execution.
Changes

Nx

- Add all_gather/2 in defn/kernel.ex and defn/expr.ex with sharding semantics
- Add evaluator support for all_gather in defn/evaluator.ex

EXLA

- Lower all_gather to stablehlo.all_gather in defn.ex and mlir/value.ex

Tests

- EXLA.Defn.ShardingTest: "generates correct MLIR with all_gather" checks MLIR generation and shard_jit output across a 2×2 mesh along axes 0 and 1
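For reference, a usage sketch in the spirit of that test (the option names follow the code quoted in this review; the 2×2 mesh and replica-group layout are assumed for illustration):

```elixir
defmodule ShardedExample do
  import Nx.Defn

  defn gather_rows(t) do
    # Gather the shards of `t` along dimension 0 within each replica group
    # of a hypothetical 2x2 mesh.
    all_gather(t, all_gather_dim: 0, replica_groups: [[0, 1], [2, 3]])
  end
end
```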