Merged
Changes from 6 commits
2 changes: 2 additions & 0 deletions RELEASES.md
@@ -7,9 +7,11 @@ This new release adds support for sparse cost matrices in the exact EMD solver.
#### New features
- Add support for sparse cost matrices in exact EMD solver `ot.emd` and `ot.emd2` (PR #778)
- Migrate backend from deprecated `scipy.sparse.coo_matrix` to modern `scipy.sparse.coo_array` API (PR #TBD)
Collaborator:
Fix PR number above and add PR number here

- Geomloss lazy plan indexing now handles both scalar and slice indices for `i` and `j`, using backend-agnostic reshaping, so that `plan[i, :]` and `plan[:, j]` both work

#### Closed issues
- Add support for sparse cost matrices in EMD solver (PR #778, Issue #397)
- Fix O(n³) performance bottleneck in sparse bipartite graph arc iteration
Collaborator:
same here

- Fix deprecated JAX function in `ot.backend.JaxBackend` (PR #771, Issue #770)
- Add test for build from source (PR #772, Issue #764)
- Fix device for batch Ot solver in `ot.batch` (PR #784, Issue #783)
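For readers skimming the release notes above, here is a minimal Python sketch of the sparse-cost path from user code, assuming `ot.emd`/`ot.emd2` accept a SciPy sparse cost matrix as the PR #778 entry describes; the arc pattern, weights, and costs are made up for illustration and are not taken from this PR's tests.

```python
# Illustrative sketch only: a small problem whose admissible arcs are stored with
# scipy.sparse.coo_array (the non-deprecated API this PR migrates to). The sparse
# input to ot.emd/ot.emd2 is an assumption based on the release note for PR #778.
import numpy as np
import ot
from scipy.sparse import coo_array

n = 4
a = np.full(n, 1.0 / n)   # uniform source weights
b = np.full(n, 1.0 / n)   # uniform target weights

# Admissible (row, col) arcs and their costs; other entries are simply not stored.
row = np.array([0, 0, 1, 1, 2, 2, 3, 3])
col = np.array([0, 1, 1, 2, 2, 3, 3, 0])
cost = np.array([0.1, 0.9, 0.2, 0.8, 0.3, 0.7, 0.4, 0.6])
M = coo_array((cost, (row, col)), shape=(n, n))

G = ot.emd(a, b, M)      # transport plan (per the PR #778 entry above)
val = ot.emd2(a, b, M)   # optimal transport cost
print(val)
```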
29 changes: 24 additions & 5 deletions ot/bregman/_geomloss.py
@@ -54,13 +54,32 @@ def get_sinkhorn_geomloss_lazytensor(
    shape = (X_a.shape[0], X_b.shape[0])

    def func(i, j, X_a, X_b, f, g, a, b, metric, blur):
+        X_a_i = X_a[i]
+        X_b_j = X_b[j]
+
+        if X_a_i.ndim == 1:
+            X_a_i = X_a_i[None, :]
+        if X_b_j.ndim == 1:
+            X_b_j = X_b_j[None, :]
+
        if metric == "sqeuclidean":
-            C = dist(X_a[i], X_b[j], metric=metric) / 2
+            C = dist(X_a_i, X_b_j, metric=metric) / 2
        else:
-            C = dist(X_a[i], X_b[j], metric=metric)
-        return nx.exp((f[i, None] + g[None, j] - C) / (blur**2)) * (
-            a[i, None] * b[None, j]
-        )
+            C = dist(X_a_i, X_b_j, metric=metric)
+
+        # Robust broadcasting using nx backend (handles both numpy and torch)
+        # For scalars, slice to keep 1D; for arrays, index directly
+        f_i = f[i : i + 1] if isinstance(i, int) else f[i]
+        g_j = g[j : j + 1] if isinstance(j, int) else g[j]
+        a_i = a[i : i + 1] if isinstance(i, int) else a[i]
+        b_j = b[j : j + 1] if isinstance(j, int) else b[j]
+
+        f_i = nx.reshape(f_i, (-1, 1))
+        g_j = nx.reshape(g_j, (1, -1))
+        a_i = nx.reshape(a_i, (-1, 1))
+        b_j = nx.reshape(b_j, (1, -1))
+
+        return nx.squeeze(nx.exp((f_i + g_j - C) / (blur**2)) * a_i * b_j)

    T = LazyTensor(
        shape, func, X_a=X_a, X_b=X_b, f=f, g=g, a=a, b=b, metric=metric, blur=blur
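To make the scalar-vs-slice handling in `func` above concrete, here is a standalone NumPy sketch of the same indexing pattern, outside of POT: scalar indices are turned into length-1 slices so a single reshape/broadcast path serves `plan[i, j]`, `plan[i, :]`, and `plan[:, j]`. The function name, the explicit squared-Euclidean cost (standing in for POT's `dist`), `np.atleast_2d` (standing in for the explicit `ndim` check), and the random data are all illustrative.

```python
# Standalone sketch of the indexing pattern in func above; names and data are
# illustrative, not POT internals.
import numpy as np

def plan_block(i, j, X_a, X_b, f, g, a, b, blur=0.1):
    # Keep a 1D view even when i or j is a plain Python int
    f_i = f[i:i + 1] if isinstance(i, int) else f[i]
    g_j = g[j:j + 1] if isinstance(j, int) else g[j]
    a_i = a[i:i + 1] if isinstance(i, int) else a[i]
    b_j = b[j:j + 1] if isinstance(j, int) else b[j]
    X_a_i = np.atleast_2d(X_a[i])
    X_b_j = np.atleast_2d(X_b[j])
    # Squared Euclidean cost, halved as in the sqeuclidean branch above
    C = ((X_a_i[:, None, :] - X_b_j[None, :, :]) ** 2).sum(-1) / 2
    K = np.exp((f_i[:, None] + g_j[None, :] - C) / blur**2)
    return np.squeeze(K * a_i[:, None] * b_j[None, :])

rng = np.random.default_rng(0)
X_a, X_b = rng.normal(size=(5, 2)), rng.normal(size=(7, 2))
f, g = rng.normal(size=5), rng.normal(size=7)
a, b = np.full(5, 0.2), np.full(7, 1 / 7)

row = plan_block(2, slice(None), X_a, X_b, f, g, a, b)   # like plan[2, :] -> shape (7,)
col = plan_block(slice(None), 3, X_a, X_b, f, g, a, b)   # like plan[:, 3] -> shape (5,)
```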
67 changes: 46 additions & 21 deletions ot/lp/sparse_bipartitegraph.h
@@ -43,8 +43,13 @@ namespace lemon {

    mutable std::vector<std::vector<Arc>> _in_arcs; // _in_arcs[node] = incoming arc IDs
    mutable bool _in_arcs_built;
+
+    // Position tracking for O(1) iteration
+    mutable std::vector<int64_t> _arc_to_out_pos; // _arc_to_out_pos[arc_id] = position in _arc_ids
+    mutable std::vector<int64_t> _arc_to_in_pos;  // _arc_to_in_pos[arc_id] = position in _in_arcs[target]
+    mutable bool _position_maps_built;

-    SparseBipartiteDigraphBase() : _node_num(0), _arc_num(0), _n1(0), _n2(0), _in_arcs_built(false) {}
+    SparseBipartiteDigraphBase() : _node_num(0), _arc_num(0), _n1(0), _n2(0), _in_arcs_built(false), _position_maps_built(false) {}

    void construct(int n1, int n2) {
      _node_num = n1 + n2;
@@ -58,6 +63,9 @@
      _arc_ids.clear();
      _in_arcs.clear();
      _in_arcs_built = false;
+      _arc_to_out_pos.clear();
+      _arc_to_in_pos.clear();
+      _position_maps_built = false;
    }

    void build_in_arcs() const {
@@ -72,6 +80,31 @@

      _in_arcs_built = true;
    }
+
+    void build_position_maps() const {
+      if (_position_maps_built) return;
+
+      _arc_to_out_pos.resize(_arc_num);
+      _arc_to_in_pos.resize(_arc_num);
+
+      // Build outgoing arc position map from CSR structure
+      for (int64_t pos = 0; pos < _arc_num; ++pos) {
+        Arc arc_id = _arc_ids[pos];
+        _arc_to_out_pos[arc_id] = pos;
+      }
+
+      // Build incoming arc position map
+      build_in_arcs();
+      for (Node node = 0; node < _node_num; ++node) {
+        const std::vector<Arc>& in = _in_arcs[node];
+        for (size_t pos = 0; pos < in.size(); ++pos) {
+          Arc arc_id = in[pos];
+          _arc_to_in_pos[arc_id] = pos;
+        }
+      }
+
+      _position_maps_built = true;
+    }

  public:

@@ -212,18 +245,14 @@

    void nextOut(Arc& arc) const {
      if (arc < 0) return;
-
+
+      build_position_maps();
+
+      int64_t pos = _arc_to_out_pos[arc];
      Node src = _arc_sources[arc];
-      int64_t start = _row_ptr[src];
      int64_t end = _row_ptr[src + 1];
-
-      for (int64_t i = start; i < end; ++i) {
-        if (_arc_ids[i] == arc) {
-          arc = (i + 1 < end) ? _arc_ids[i + 1] : Arc(-1);
-          return;
-        }
-      }
-      arc = -1;
+
+      arc = (pos + 1 < end) ? _arc_ids[pos + 1] : Arc(-1);
    }

    void firstIn(Arc& arc, const Node& node) const {
@@ -240,18 +269,14 @@

    void nextIn(Arc& arc) const {
      if (arc < 0) return;
-
+
+      build_position_maps();
+
+      int64_t pos = _arc_to_in_pos[arc];
      Node tgt = _arc_targets[arc];
      const std::vector<Arc>& in = _in_arcs[tgt];
-
-      // Find current arc in the list and return next one
-      for (size_t i = 0; i < in.size(); ++i) {
-        if (in[i] == arc) {
-          arc = (i + 1 < in.size()) ? in[i + 1] : Arc(-1);
-          return;
-        }
-      }
-      arc = -1;
+
+      arc = (pos + 1 < in.size()) ? in[pos + 1] : Arc(-1);
    }
  };

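The position maps above exist to make `nextOut`/`nextIn` constant-time: the old code rescanned the node's CSR slice to locate the current arc on every step, which is what produced the O(n³) behaviour mentioned in the release notes. Below is a small Python sketch of the same idea, restricted to the outgoing side, with hypothetical names; it mirrors the technique, not the library's actual C++ class.

```python
# Sketch of CSR out-arc iteration with a precomputed arc -> position map.
class SparseOutArcs:
    def __init__(self, n_sources, arcs):
        # arcs: list of (source, target) pairs; the arc id is the index in this list
        self.arc_source = [s for s, _ in arcs]
        # CSR layout: arc_ids[row_ptr[s]:row_ptr[s+1]] are the arcs leaving source s
        self.arc_ids = sorted(range(len(arcs)), key=lambda a: arcs[a][0])
        self.row_ptr = [0] * (n_sources + 1)
        for s, _ in arcs:
            self.row_ptr[s + 1] += 1
        for s in range(n_sources):
            self.row_ptr[s + 1] += self.row_ptr[s]
        # Position map: arc id -> slot inside arc_ids, built once in O(#arcs)
        self.out_pos = [0] * len(arcs)
        for pos, arc in enumerate(self.arc_ids):
            self.out_pos[arc] = pos

    def first_out(self, source):
        lo, hi = self.row_ptr[source], self.row_ptr[source + 1]
        return self.arc_ids[lo] if lo < hi else -1

    def next_out(self, arc):
        # O(1): jump straight to this arc's slot instead of rescanning the row
        end = self.row_ptr[self.arc_source[arc] + 1]
        pos = self.out_pos[arc]
        return self.arc_ids[pos + 1] if pos + 1 < end else -1


g = SparseOutArcs(3, [(0, 0), (0, 2), (1, 1), (2, 0), (2, 2)])
arc = g.first_out(0)
while arc != -1:          # visits the two arcs leaving node 0 in O(1) per step
    print("out-arc of node 0:", arc)
    arc = g.next_out(arc)
```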