diff --git a/CHANGES.md b/CHANGES.md index 7f24f12a..7bfc8cfc 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -30,6 +30,7 @@ We're closing 8 user-reported issues or feature requests and are totalling 30 co ### Nx +- Fix fancy slicing so `L [...]` with full axis length correctly supports permutations and duplicates; only identity order is a no‑op (#152, @Arsalaan-Alam) - Fix einsum output axis ordering for free axes (e.g., `i,jk->jki`, `ij,klj->kli`) by correcting final transpose permutation and intermediate left-axis reordering. (@tmattio) - Add `Nx_io.Cache_dir` module with consolidated cache directory utilities respecting `RAVEN_CACHE_ROOT`, `XDG_CACHE_HOME`, and `HOME` fallback, replacing project-specific cache logic across the whole raven ecosystem (#134, @Arsalaan-Alam) - Add `Nx_io.save_txt` / `Nx_io.load_txt` with NumPy-compatible formatting, comments, and dtype support (#120, @six-shot) diff --git a/nx/lib/core/frontend.ml b/nx/lib/core/frontend.ml index 77afd369..9a0a1e23 100644 --- a/nx/lib/core/frontend.ml +++ b/nx/lib/core/frontend.ml @@ -2969,7 +2969,18 @@ module Make (B : Backend_intf.S) = struct let dim_size = (shape tensor).(working_dim) in let indices = indices_of_spec dim_size spec in let tensor' = - if List.length indices = dim_size then tensor + (* If indices cover the whole dimension in identity + order, this is a no-op. Otherwise we must gather + (to support permutations and duplicates). *) + let is_range n indices = + let rec traverse id = function + | [] -> id = n + | hd :: tl when hd = id -> traverse (succ id) tl + | _ -> false + in + traverse 0 indices + in + if is_range dim_size indices then tensor else if List.length indices = 0 then ( (* Empty slice - create tensor with 0 size in this dimension *) diff --git a/nx/test/test_nx_indexing.ml b/nx/test/test_nx_indexing.ml index 906ffc85..43c2d1ec 100644 --- a/nx/test/test_nx_indexing.ml +++ b/nx/test/test_nx_indexing.ml @@ -78,6 +78,18 @@ let test_index_mixed () = [| 20.; 21.; 22.; 23.; 24.; 30.; 31.; 32.; 33.; 34. |] indexed +(* Fancy indexing along axis 0 should reorder rows (and support duplicates) *) +let test_index_idx_reorder_rows () = + let t = Nx.create Nx.float32 [| 3; 2 |] (Array.init 6 float_of_int) in + let indexed = Nx.slice [ Nx.L [ 2; 0; 1 ]; Nx.A ] t in + check_t "index idx reorder rows" [| 3; 2 |] [| 4.; 5.; 0.; 1.; 2.; 3. |] indexed + +let test_index_idx_duplicate_rows () = + let t = Nx.create Nx.float32 [| 3; 2 |] (Array.init 6 float_of_int) in + let indexed = Nx.slice [ Nx.L [ 1; 1; 0 ]; Nx.A ] t in + check_t "index idx duplicate rows" [| 3; 2 |] [| 2.; 3.; 2.; 3.; 0.; 1. |] + indexed + (* Note: `new_ and `mask require implementation *) (* let test_index_new_axis () = let t = Nx.create Nx.float32 [| 3; 4 |] (Array.init 12 float_of_int) in @@ -396,6 +408,8 @@ let index_tests = ("index idx", `Quick, test_index_idx); ("index idx repeated", `Quick, test_index_idx_repeated); ("index mixed", `Quick, test_index_mixed); + ("index idx reorder rows", `Quick, test_index_idx_reorder_rows); + ("index idx duplicate rows", `Quick, test_index_idx_duplicate_rows); ("set_slice at", `Quick, test_set_slice_at); ("set_slice rng", `Quick, test_set_slice_rng); ("set_slice idx", `Quick, test_set_slice_idx);