
Use depth-first search in _build_data_index #2156

Draft
YigitElma wants to merge 6 commits into master from yge/build_data_index

Conversation

Collaborator

@YigitElma YigitElma commented Apr 9, 2026

Importing desc.compute takes a very long time; on my system, _build_data_index alone takes 3.5 seconds. This change makes it run in 0.15 seconds!

The function now first computes a topological order over all the keys, such that keys with no dependencies (like R, Z, lambda) come first, followed by keys that depend only on those, and so on. With such an order, the full dependency set for a key is simply the union of data_index[p][k]["dependencies"]["data"] with the full dependencies of each direct dependency (which have already been computed, thanks to the ordering). No recursion is needed.

The previous function was slow because of recursion, but it also made an extreme number of redundant calls. For example, get_params, get_profiles, and get_derivs each call get_deps again internally, and since the result of the first call was not yet cached, the whole tree was rebuilt almost 8 more times.
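The topological-order scheme described above can be sketched as follows. This is a minimal illustration, not the actual _build_data_index code; the toy `deps` graph and the `topological_order` helper are made up for the example.

```python
# Toy dependency graph standing in for the real data_index entries.
deps = {
    "R": [],
    "Z": [],
    "lambda": [],
    "e_theta": ["R", "Z"],
    "sqrt(g)": ["e_theta", "lambda"],
}

def topological_order(deps):
    """Order keys so every key appears after all of its dependencies."""
    order, visited = [], set()
    for start in deps:
        stack = [(start, False)]  # iterative DFS, no recursion
        while stack:
            key, children_done = stack.pop()
            if children_done:
                order.append(key)  # all deps of `key` are already in `order`
            elif key not in visited:
                visited.add(key)
                stack.append((key, True))
                stack.extend((d, False) for d in deps[key])
    return order

full = {}
for key in topological_order(deps):
    # Direct deps, plus their already-cached full deps: one flat union
    # per key instead of re-walking the whole tree recursively.
    full[key] = set(deps[key]).union(*(full[d] for d in deps[key]))
```

For instance, `full["sqrt(g)"]` comes out as `{"e_theta", "lambda", "R", "Z"}` with a single union, because `full["e_theta"]` was already built earlier in the order.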

This code was originally written by Claude Code (Opus 4.6), but I wrote most of the comments and I understand the algorithm behind it.

Also removes the default_quad computation at import time.

Resolves #2154

@YigitElma YigitElma self-assigned this Apr 9, 2026
@YigitElma YigitElma added the run_benchmarks Run timing benchmarks on this PR against current master branch label Apr 9, 2026

github-actions bot commented Apr 10, 2026

Memory benchmark result

| Test Name | %Δ | Master (MB) | PR (MB) | Δ (MB) | Time PR (s) | Time Master (s) |
| --- | --- | --- | --- | --- | --- | --- |
| test_objective_jac_w7x | 2.57 % | 4.084e+03 | 4.189e+03 | 105.09 | 32.43 | 38.24 |
| test_proximal_jac_w7x_with_eq_update | -0.57 % | 6.536e+03 | 6.499e+03 | -37.02 | 151.74 | 158.04 |
| test_proximal_freeb_jac | -0.32 % | 1.340e+04 | 1.335e+04 | -43.30 | 80.11 | 84.51 |
| test_proximal_freeb_jac_blocked | -0.01 % | 7.742e+03 | 7.741e+03 | -0.76 | 69.46 | 76.59 |
| test_proximal_freeb_jac_batched | -0.89 % | 7.741e+03 | 7.672e+03 | -68.88 | 68.66 | 75.80 |
| test_proximal_jac_ripple | -2.44 % | 3.681e+03 | 3.592e+03 | -89.98 | 54.76 | 61.95 |
| test_proximal_jac_ripple_bounce1d | 0.40 % | 3.819e+03 | 3.835e+03 | 15.29 | 67.39 | 74.41 |
| test_eq_solve | -1.34 % | 2.225e+03 | 2.195e+03 | -29.92 | 88.94 | 96.87 |

For the memory plots, go to the summary of Memory Benchmarks workflow and download the artifact.


github-actions bot commented Apr 10, 2026

| benchmark_name | dt(%) | dt(s) | t_new(s) | t_old(s) |
| --- | --- | --- | --- | --- |
| test_build_transform_fft_lowres | +2.48 +/- 3.81 | +2.14e-02 +/- 3.29e-02 | 8.86e-01 +/- 3.1e-02 | 8.65e-01 +/- 1.0e-02 |
| test_equilibrium_init_medres | +0.36 +/- 3.87 | +2.49e-02 +/- 2.66e-01 | 6.91e+00 +/- 2.3e-01 | 6.89e+00 +/- 1.3e-01 |
| test_equilibrium_init_highres | -1.07 +/- 3.64 | -8.24e-02 +/- 2.81e-01 | 7.64e+00 +/- 2.1e-01 | 7.72e+00 +/- 1.8e-01 |
| test_objective_compile_dshape_current | +0.41 +/- 1.64 | +1.78e-02 +/- 7.10e-02 | 4.36e+00 +/- 4.3e-02 | 4.34e+00 +/- 5.7e-02 |
| test_objective_compute_dshape_current | +5.05 +/- 19.90 | +3.75e-05 +/- 1.48e-04 | 7.81e-04 +/- 5.3e-05 | 7.43e-04 +/- 1.4e-04 |
| test_objective_jac_dshape_current | -0.48 +/- 22.62 | -1.17e-04 +/- 5.49e-03 | 2.42e-02 +/- 3.2e-03 | 2.43e-02 +/- 4.4e-03 |
| test_perturb_2 | +4.45 +/- 3.53 | +9.18e-01 +/- 7.28e-01 | 2.16e+01 +/- 6.0e-01 | 2.06e+01 +/- 4.1e-01 |
| test_proximal_jac_atf_with_eq_update | +1.01 +/- 1.92 | +1.25e-01 +/- 2.40e-01 | 1.26e+01 +/- 2.0e-01 | 1.25e+01 +/- 1.4e-01 |
| test_proximal_freeb_jac | +0.11 +/- 2.55 | +5.20e-03 +/- 1.25e-01 | 4.91e+00 +/- 9.9e-02 | 4.91e+00 +/- 7.6e-02 |
| test_solve_fixed_iter_compiled | +0.91 +/- 1.45 | +7.53e-02 +/- 1.20e-01 | 8.37e+00 +/- 1.1e-01 | 8.29e+00 +/- 3.8e-02 |
| test_LinearConstraintProjection_build | -1.78 +/- 5.26 | -1.73e-01 +/- 5.12e-01 | 9.56e+00 +/- 3.7e-01 | 9.74e+00 +/- 3.5e-01 |
| test_objective_compute_ripple_bounce1d | +0.82 +/- 4.51 | +2.45e-03 +/- 1.36e-02 | 3.04e-01 +/- 9.8e-03 | 3.01e-01 +/- 9.4e-03 |
| test_objective_grad_ripple_bounce1d | +0.99 +/- 2.14 | +9.35e-03 +/- 2.02e-02 | 9.55e-01 +/- 1.6e-02 | 9.45e-01 +/- 1.2e-02 |
| test_build_transform_fft_midres | -1.65 +/- 3.17 | -1.41e-02 +/- 2.71e-02 | 8.40e-01 +/- 2.0e-02 | 8.54e-01 +/- 1.8e-02 |
| test_build_transform_fft_highres | -1.64 +/- 1.80 | -1.78e-02 +/- 1.96e-02 | 1.07e+00 +/- 1.8e-02 | 1.09e+00 +/- 6.4e-03 |
| test_equilibrium_init_lowres | -0.93 +/- 2.59 | -5.96e-02 +/- 1.66e-01 | 6.36e+00 +/- 1.2e-01 | 6.42e+00 +/- 1.2e-01 |
| test_objective_compile_atf | +0.19 +/- 3.72 | +1.12e-02 +/- 2.23e-01 | 5.99e+00 +/- 1.4e-01 | 5.98e+00 +/- 1.8e-01 |
| test_objective_compute_atf | -1.08 +/- 13.41 | -2.62e-05 +/- 3.25e-04 | 2.40e-03 +/- 2.3e-04 | 2.43e-03 +/- 2.3e-04 |
| test_objective_jac_atf | +0.99 +/- 5.83 | +1.71e-02 +/- 1.01e-01 | 1.75e+00 +/- 7.5e-02 | 1.74e+00 +/- 6.8e-02 |
| test_perturb_1 | +0.29 +/- 2.87 | +4.51e-02 +/- 4.49e-01 | 1.57e+01 +/- 3.8e-01 | 1.56e+01 +/- 2.4e-01 |
| test_proximal_jac_atf | -0.06 +/- 1.63 | -3.39e-03 +/- 8.68e-02 | 5.31e+00 +/- 6.4e-02 | 5.32e+00 +/- 5.8e-02 |
| test_proximal_freeb_compute | -0.80 +/- 3.58 | -1.17e-03 +/- 5.23e-03 | 1.45e-01 +/- 3.8e-03 | 1.46e-01 +/- 3.6e-03 |
| test_solve_fixed_iter | +0.23 +/- 1.99 | +6.67e-02 +/- 5.90e-01 | 2.97e+01 +/- 4.2e-01 | 2.96e+01 +/- 4.2e-01 |
| test_objective_compute_ripple | -2.51 +/- 6.26 | -5.43e-03 +/- 1.36e-02 | 2.11e-01 +/- 7.1e-03 | 2.17e-01 +/- 1.2e-02 |
| test_objective_grad_ripple | -0.18 +/- 3.32 | -1.73e-03 +/- 3.21e-02 | 9.64e-01 +/- 2.9e-02 | 9.66e-01 +/- 1.3e-02 |

GitHub CI performance can be noisy; developers should take this into account when evaluating the benchmarks.

@YigitElma YigitElma marked this pull request as draft April 10, 2026 03:10

@unalmis unalmis left a comment


My suggestion on the issue avoids an eigenvalue solve on import. The current changes force an eigenvalue solve every time an object is made, which is worse for optimization and for debugging with JIT in general. I didn't check the Claude stuff.

```
            (automorphism_sin, grad_automorphism_sin),
        ),
    ),
    automorphism,
```

@unalmis unalmis Apr 10, 2026


  • This will now solve an eigenvalue problem every time Bounce*D is called.
  • Autodiff will now also trace it into the graph, and JIT is imperfect and does not work well inside scan loops, which is where this stuff is called. For error analysis, it is vital that the quadrature is not differentiated, and leaving that up to JIT under scan is not something I want, given how often bugs are found in it. Thus, this change makes it harder to confirm correctness.
  • On the other hand, what I suggested in the original issue (just changing the default_quad defined in this file globally) avoids an eigenvalue solve on import.

Collaborator Author


#2154 (comment)
This was my way of making default_quad a required(ish) argument; if someone wants to avoid this computation, they can pass a precomputed value. If you want to make it actually required, you can make that change. My point is that there shouldn't be any computation at import; not everyone uses this function.


@unalmis unalmis Apr 10, 2026


> if someone wants to avoid this computation, they can pass a precomputed value.

This is false. Your changes force an eigenvalue solve every time Bounce*D is called.


@unalmis unalmis Apr 10, 2026


  • Trivial computations like jnp.cos(jnp.arange(32)) on import are done by the packages we import, like numpy. For example, they load things like powers of two into cache, to name one. Unlike avoiding an eigenvalue solve, avoiding a call to a cosine function is probably an unnecessary micro-optimization.

  • jnp.cos(jnp.arange(32)) probably takes comparable time to defining a Python function in some file, so it doesn't make sense to me to single it out as a computation. Everything is machine code; everything is a computation.

  • Previously, users could import default_quad, a public variable, into their scripts. They can no longer do that with your change, so user code will break.

  • A very simple solution is to just use a default quad that avoids the eigenvalue solve, as mentioned in the original issue.

Member


The main cost of leggauss(32) is the JIT compiler overhead; the actual calculation takes ~1 ms. So one option (if we want to keep using Legendre for the default quad) could be to just use the original np.polynomial.legendre.leggauss for default_quad, which avoids the JIT overhead at init (and since it's just a constant, it doesn't really need to be JAX anyway). Switching to Chebyshev quadrature avoids the eigenvalue solve but may still have some JIT overhead, leading to similar issues (although in either case a pure np version may be faster). I'll defer to Kaya on the accuracy impact of Legendre vs Chebyshev.
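The pure-NumPy option described above can be sketched like this. It is only an illustration of the idea, not how default_quad is actually wired up in desc; leggauss runs on the host in about a millisecond, involves no JAX tracing or JIT, and produces a constant pair of arrays.

```python
import numpy as np

# Gauss-Legendre nodes and weights on [-1, 1], computed once in plain
# NumPy. No XLA compilation is triggered, and the result can be wrapped
# with jnp.asarray later if a JAX array is needed downstream.
x, w = np.polynomial.legendre.leggauss(32)

# Sanity check: the weights of a quadrature rule on [-1, 1] sum to the
# interval length, 2.
assert np.isclose(w.sum(), 2.0)
```

Since the nodes and weights are module-level constants either way, doing this in NumPy sidesteps the compile-at-init cost entirely.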

Collaborator


The accuracy difference is negligible. Does jnp.asarray(numpy computation) have JIT overhead? I don't want issues like this one, where JAX bakes the numpy constant into the graph:

```
# reduces memory usage by > 400% for the forward computation and Jacobian.
```

```
get_quadrature(
    leggauss(32),
    (automorphism_sin, grad_automorphism_sin),
),
```
Collaborator


```
set_device("cpu")
x = jnp.linspace(0, 5)
y = jnp.exp(x)
x = jnp.arange(2)
```

@unalmis unalmis Apr 10, 2026


I don't know what the purpose of this change is, but generally you want to do math to test that this stuff is working. arange and the like can often execute on the host on old jax (< 0.9) versions.

Collaborator Author

@YigitElma YigitElma Apr 10, 2026


I think that to debug any problem with JAX, any computation that triggers XLA is enough; at least so far I've never seen otherwise. Pure imports don't surface actual problems, but any jnp function does the job. I just made the change to get that check minimally. I don't have a strong opinion about it, but I prefer something minimal.

Member


I'd prefer to keep the old version. As Kaya mentioned, older versions of jax will always execute arange on the host, meaning GPU issues won't necessarily show up.
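The kind of smoke test discussed above can be sketched as follows. This mirrors the linspace/exp lines quoted earlier in the thread (minus the desc-specific set_device call): the point is to force an actual elementwise XLA computation on the default device, since a bare arange may run on the host under older jax versions and hide GPU problems.

```python
import jax.numpy as jnp

# linspace + exp dispatch a real XLA kernel on the default device,
# unlike a trivial arange, which older jax (< 0.9) may execute on host.
x = jnp.linspace(0, 5)  # 50 points by default
y = jnp.exp(x)          # forces device computation; fails loudly if broken
```

If the device setup is broken, the exp call is where the failure surfaces, rather than at import time.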


codecov bot commented Apr 10, 2026

Codecov Report

❌ Patch coverage is 97.10145% with 2 lines in your changes missing coverage. Please review.
✅ Project coverage is 94.40%. Comparing base (2d9d35a) to head (f05ffdc).

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| desc/backend.py | 33.33% | 2 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #2156      +/-   ##
==========================================
- Coverage   94.45%   94.40%   -0.05%     
==========================================
  Files         101      101              
  Lines       28593    28648      +55     
==========================================
+ Hits        27008    27046      +38     
- Misses       1585     1602      +17     
| Files with missing lines | Coverage Δ |
| --- | --- |
| desc/compute/__init__.py | 100.00% <100.00%> (ø) |
| desc/integrals/bounce_integral.py | 97.95% <100.00%> (-0.01%) ⬇️ |
| desc/backend.py | 90.15% <33.33%> (+0.41%) ⬆️ |

... and 3 files with indirect coverage changes


Comment on lines +105 to +110
```
# --- Step 1: Topological sort via iterative Depth-First Search ---
# We need to process quantities with no dependencies before the
# quantities that depend on them. This way, when we process key K,
# all of K's dependencies already have their full_dependencies and
# full_with_axis_dependencies cached, and we can build K's full
# dependency set with a simple set union instead of deep recursion.
```
Member


This sounds like it's basically doing what set_tier is doing below, so ideally we could do them both at the same time.

Collaborator Author


Yeah, I was actually going to suggest that in the looped compute PR. The order obtained at the end can be used directly; that is why I have a special sort instead of a normal sort.


Labels

run_benchmarks Run timing benchmarks on this PR against current master branch

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Import time is too long

3 participants