Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ We're closing 8 user-reported issues or feature requests and are totalling 30 co

### Nx

- Fix fancy slicing so `L [...]` with full axis length correctly supports permutations and duplicates; only identity order is a no‑op (#152, @Arsalaan-Alam)
- Fix einsum output axis ordering for free axes (e.g., `i,jk->jki`, `ij,klj->kli`) by correcting final transpose permutation and intermediate left-axis reordering. (@tmattio)
- Add `Nx_io.Cache_dir` module with consolidated cache directory utilities respecting `RAVEN_CACHE_ROOT`, `XDG_CACHE_HOME`, and `HOME` fallback, replacing project-specific cache logic across the whole raven ecosystem (#134, @Arsalaan-Alam)
- Add `Nx_io.save_txt` / `Nx_io.load_txt` with NumPy-compatible formatting, comments, and dtype support (#120, @six-shot)
Expand Down
13 changes: 12 additions & 1 deletion nx/lib/core/frontend.ml
Original file line number Diff line number Diff line change
Expand Up @@ -2969,7 +2969,18 @@ module Make (B : Backend_intf.S) = struct
let dim_size = (shape tensor).(working_dim) in
let indices = indices_of_spec dim_size spec in
let tensor' =
if List.length indices = dim_size then tensor
(* If indices cover the whole dimension in identity
order, this is a no-op. Otherwise we must gather
(to support permutations and duplicates). *)
let is_range n indices =
let rec traverse id = function
| [] -> id = n
| hd :: tl when hd = id -> traverse (succ id) tl
| _ -> false
in
traverse 0 indices
in
if is_range dim_size indices then tensor
else if List.length indices = 0 then (
(* Empty slice - create tensor with 0 size in this
dimension *)
Expand Down
14 changes: 14 additions & 0 deletions nx/test/test_nx_indexing.ml
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,18 @@ let test_index_mixed () =
[| 20.; 21.; 22.; 23.; 24.; 30.; 31.; 32.; 33.; 34. |]
indexed

(* Fancy indexing along axis 0 should reorder rows (and support duplicates) *)
let test_index_idx_reorder_rows () =
let t = Nx.create Nx.float32 [| 3; 2 |] (Array.init 6 float_of_int) in
let indexed = Nx.slice [ Nx.L [ 2; 0; 1 ]; Nx.A ] t in
check_t "index idx reorder rows" [| 3; 2 |] [| 4.; 5.; 0.; 1.; 2.; 3. |] indexed

let test_index_idx_duplicate_rows () =
let t = Nx.create Nx.float32 [| 3; 2 |] (Array.init 6 float_of_int) in
let indexed = Nx.slice [ Nx.L [ 1; 1; 0 ]; Nx.A ] t in
check_t "index idx duplicate rows" [| 3; 2 |] [| 2.; 3.; 2.; 3.; 0.; 1. |]
indexed

(* Note: `new_ and `mask require implementation *)
(* let test_index_new_axis () =
let t = Nx.create Nx.float32 [| 3; 4 |] (Array.init 12 float_of_int) in
Expand Down Expand Up @@ -396,6 +408,8 @@ let index_tests =
("index idx", `Quick, test_index_idx);
("index idx repeated", `Quick, test_index_idx_repeated);
("index mixed", `Quick, test_index_mixed);
("index idx reorder rows", `Quick, test_index_idx_reorder_rows);
("index idx duplicate rows", `Quick, test_index_idx_duplicate_rows);
("set_slice at", `Quick, test_set_slice_at);
("set_slice rng", `Quick, test_set_slice_rng);
("set_slice idx", `Quick, test_set_slice_idx);
Expand Down
Loading