Use depth-first search in `_build_data_index` #2156
…d loop to prevent recursion
Memory benchmark result

| Test Name | %Δ | Master (MB) | PR (MB) | Δ (MB) | Time PR (s) | Time Master (s) |
| --- | --- | --- | --- | --- | --- | --- |
| test_objective_jac_w7x | 2.57 % | 4.084e+03 | 4.189e+03 | 105.09 | 32.43 | 38.24 |
| test_proximal_jac_w7x_with_eq_update | -0.57 % | 6.536e+03 | 6.499e+03 | -37.02 | 151.74 | 158.04 |
| test_proximal_freeb_jac | -0.32 % | 1.340e+04 | 1.335e+04 | -43.30 | 80.11 | 84.51 |
| test_proximal_freeb_jac_blocked | -0.01 % | 7.742e+03 | 7.741e+03 | -0.76 | 69.46 | 76.59 |
| test_proximal_freeb_jac_batched | -0.89 % | 7.741e+03 | 7.672e+03 | -68.88 | 68.66 | 75.80 |
| test_proximal_jac_ripple | -2.44 % | 3.681e+03 | 3.592e+03 | -89.98 | 54.76 | 61.95 |
| test_proximal_jac_ripple_bounce1d | 0.40 % | 3.819e+03 | 3.835e+03 | 15.29 | 67.39 | 74.41 |
| test_eq_solve | -1.34 % | 2.225e+03 | 2.195e+03 | -29.92 | 88.94 | 96.87 |

For the memory plots, go to the summary of …
| benchmark_name | dt(%) | dt(s) | t_new(s) | t_old(s) |
| --- | --- | --- | --- | --- |
| test_build_transform_fft_lowres | +2.48 +/- 3.81 | +2.14e-02 +/- 3.29e-02 | 8.86e-01 +/- 3.1e-02 | 8.65e-01 +/- 1.0e-02 |
| test_equilibrium_init_medres | +0.36 +/- 3.87 | +2.49e-02 +/- 2.66e-01 | 6.91e+00 +/- 2.3e-01 | 6.89e+00 +/- 1.3e-01 |
| test_equilibrium_init_highres | -1.07 +/- 3.64 | -8.24e-02 +/- 2.81e-01 | 7.64e+00 +/- 2.1e-01 | 7.72e+00 +/- 1.8e-01 |
| test_objective_compile_dshape_current | +0.41 +/- 1.64 | +1.78e-02 +/- 7.10e-02 | 4.36e+00 +/- 4.3e-02 | 4.34e+00 +/- 5.7e-02 |
| test_objective_compute_dshape_current | +5.05 +/- 19.90 | +3.75e-05 +/- 1.48e-04 | 7.81e-04 +/- 5.3e-05 | 7.43e-04 +/- 1.4e-04 |
| test_objective_jac_dshape_current | -0.48 +/- 22.62 | -1.17e-04 +/- 5.49e-03 | 2.42e-02 +/- 3.2e-03 | 2.43e-02 +/- 4.4e-03 |
| test_perturb_2 | +4.45 +/- 3.53 | +9.18e-01 +/- 7.28e-01 | 2.16e+01 +/- 6.0e-01 | 2.06e+01 +/- 4.1e-01 |
| test_proximal_jac_atf_with_eq_update | +1.01 +/- 1.92 | +1.25e-01 +/- 2.40e-01 | 1.26e+01 +/- 2.0e-01 | 1.25e+01 +/- 1.4e-01 |
| test_proximal_freeb_jac | +0.11 +/- 2.55 | +5.20e-03 +/- 1.25e-01 | 4.91e+00 +/- 9.9e-02 | 4.91e+00 +/- 7.6e-02 |
| test_solve_fixed_iter_compiled | +0.91 +/- 1.45 | +7.53e-02 +/- 1.20e-01 | 8.37e+00 +/- 1.1e-01 | 8.29e+00 +/- 3.8e-02 |
| test_LinearConstraintProjection_build | -1.78 +/- 5.26 | -1.73e-01 +/- 5.12e-01 | 9.56e+00 +/- 3.7e-01 | 9.74e+00 +/- 3.5e-01 |
| test_objective_compute_ripple_bounce1d | +0.82 +/- 4.51 | +2.45e-03 +/- 1.36e-02 | 3.04e-01 +/- 9.8e-03 | 3.01e-01 +/- 9.4e-03 |
| test_objective_grad_ripple_bounce1d | +0.99 +/- 2.14 | +9.35e-03 +/- 2.02e-02 | 9.55e-01 +/- 1.6e-02 | 9.45e-01 +/- 1.2e-02 |
| test_build_transform_fft_midres | -1.65 +/- 3.17 | -1.41e-02 +/- 2.71e-02 | 8.40e-01 +/- 2.0e-02 | 8.54e-01 +/- 1.8e-02 |
| test_build_transform_fft_highres | -1.64 +/- 1.80 | -1.78e-02 +/- 1.96e-02 | 1.07e+00 +/- 1.8e-02 | 1.09e+00 +/- 6.4e-03 |
| test_equilibrium_init_lowres | -0.93 +/- 2.59 | -5.96e-02 +/- 1.66e-01 | 6.36e+00 +/- 1.2e-01 | 6.42e+00 +/- 1.2e-01 |
| test_objective_compile_atf | +0.19 +/- 3.72 | +1.12e-02 +/- 2.23e-01 | 5.99e+00 +/- 1.4e-01 | 5.98e+00 +/- 1.8e-01 |
| test_objective_compute_atf | -1.08 +/- 13.41 | -2.62e-05 +/- 3.25e-04 | 2.40e-03 +/- 2.3e-04 | 2.43e-03 +/- 2.3e-04 |
| test_objective_jac_atf | +0.99 +/- 5.83 | +1.71e-02 +/- 1.01e-01 | 1.75e+00 +/- 7.5e-02 | 1.74e+00 +/- 6.8e-02 |
| test_perturb_1 | +0.29 +/- 2.87 | +4.51e-02 +/- 4.49e-01 | 1.57e+01 +/- 3.8e-01 | 1.56e+01 +/- 2.4e-01 |
| test_proximal_jac_atf | -0.06 +/- 1.63 | -3.39e-03 +/- 8.68e-02 | 5.31e+00 +/- 6.4e-02 | 5.32e+00 +/- 5.8e-02 |
| test_proximal_freeb_compute | -0.80 +/- 3.58 | -1.17e-03 +/- 5.23e-03 | 1.45e-01 +/- 3.8e-03 | 1.46e-01 +/- 3.6e-03 |
| test_solve_fixed_iter | +0.23 +/- 1.99 | +6.67e-02 +/- 5.90e-01 | 2.97e+01 +/- 4.2e-01 | 2.96e+01 +/- 4.2e-01 |
| test_objective_compute_ripple | -2.51 +/- 6.26 | -5.43e-03 +/- 1.36e-02 | 2.11e-01 +/- 7.1e-03 | 2.17e-01 +/- 1.2e-02 |
| test_objective_grad_ripple | -0.18 +/- 3.32 | -1.73e-03 +/- 3.21e-02 | 9.64e-01 +/- 2.9e-02 | 9.66e-01 +/- 1.3e-02 |

GitHub CI performance can be noisy. When evaluating the benchmarks, developers should take this into account.
…mplify desc.backend jax tests
unalmis left a comment:

My suggestion on the issue avoids an eigenvalue solve on import. The current changes force an eigenvalue solve every time an object is made, which is worse for optimization and debugging with JIT in general. I didn't check the Claude-generated code.
```python
(automorphism_sin, grad_automorphism_sin),
),
),
automorphism,
```
- This will now solve an eigenvalue problem every time `Bounce*D` is called.
- Autodiff will now also trace it into the graph, and JIT is imperfect and does not work well inside scan loops, which is where this is called. For error analysis, it is vital that the quadrature is not differentiated, and leaving that up to JIT under scan is not something I want, given how often bugs are found there. This change thus makes it harder to confirm correctness.
- On the other hand, what I suggested in the original issue, just changing the `default_quad` defined in this file globally, avoids an eigenvalue solve on import.
#2154 (comment)
This was my way of making `default_quad` a required(ish) argument; if someone wants to avoid this computation, they can pass a precomputed value. If you want to make it actually required, you can make that change. My point is, there shouldn't be any computation on import; not everyone uses this function.
> if someone wants to avoid this computation, they can pass a precomputed value.

This is false. Your changes force an eigenvalue solve every time `Bounce*D` is called.
- Trivial computations like `jnp.cos(jnp.arange(32))` on import are done by the packages we import, like NumPy. For example, they load things like powers of two into cache, just to name one. Unlike avoiding an eigenvalue solve, avoiding a cosine call is probably an unnecessary micro-optimization.
- `jnp.cos(jnp.arange(32))` probably consumes time comparable to defining a Python function in some file, so it doesn't make sense to me to single that out as a computation. Everything is machine code; everything is a computation.
- Previously, users could import `default_quad`, a public variable, into their scripts. They can no longer do that with your change, so user code will break.
- A very simple solution is to just use a default quad that avoids the eigenvalue solve, as mentioned in the original issue.
The main cost of `leggauss(32)` is the JIT compiler overhead; the actual calculation takes ~1 ms. So one option (if we want to keep using Legendre for the default quad) could be to just use the original `np.polynomial.legendre.leggauss` for `default_quad`, which avoids the JIT overhead at init (and since it's just a constant, it doesn't really need to be JAX anyway). Switching to Chebyshev quadrature avoids the eigenvalue solve but may still have some JIT overhead, leading to similar issues (although in either case a pure NumPy version may be faster). I'll defer to Kaya on the accuracy impact of Legendre vs. Chebyshev.
The accuracy difference is negligible. Does `jnp.asarray(numpy computation)` have JIT overhead? I don't want issues like this one, where JAX bakes in the NumPy constant:
`DESC/desc/integrals/surface_integral.py`, line 285 in 2d9d35a:

```python
get_quadrature(
    leggauss(32),
    (automorphism_sin, grad_automorphism_sin),
),
```
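For illustration, the pure-NumPy option discussed above might look something like the following sketch. The name `make_default_quad` is hypothetical, not DESC's API; the point is only that `np.polynomial.legendre.leggauss` runs eagerly on the host, so no JIT compilation or traced eigenvalue solve happens at import or init time.

```python
import numpy as np

# Hypothetical sketch (not DESC's actual code): compute the default
# quadrature once, eagerly, with pure NumPy. The result is a plain
# constant; downstream JAX code can wrap it with jnp.asarray, which is
# just a cheap host-to-device copy rather than a traced computation.
def make_default_quad(deg=32):
    # leggauss runs eagerly in ~1 ms with no JIT compiler overhead.
    x, w = np.polynomial.legendre.leggauss(deg)
    return x, w

x, w = make_default_quad()
# Sanity check: Gauss-Legendre weights on [-1, 1] sum to the interval
# length, 2, and 32 nodes integrate low-degree polynomials exactly.
assert np.isclose(w.sum(), 2.0)
```

Whether this is preferable to Chebyshev nodes is the accuracy question deferred above; this sketch only addresses the import-time cost.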
```python
set_device("cpu")
x = jnp.linspace(0, 5)
y = jnp.exp(x)
x = jnp.arange(2)
```
I don't know what the purpose of this change is, but generally you want to do math to test that this stuff is working. `arange` and the like can often execute on the host on old JAX (< 0.9) versions.
I think to debug any problem with JAX, any computation that triggers XLA is enough; at least until now I've never seen otherwise. Pure imports don't reveal actual problems, but any `jnp` function does the job. I just made the change to keep that check minimal. I don't have a strong opinion about it, but I prefer something minimal.
I'd prefer to keep the old version. As Kaya mentioned, older versions of JAX will always execute `arange` on the host, meaning GPU issues won't necessarily show up.
Codecov Report: ❌ Patch coverage is …
Additional details and impacted files:

```
@@            Coverage Diff             @@
##           master    #2156      +/-   ##
==========================================
- Coverage   94.45%   94.40%    -0.05%
==========================================
  Files         101      101
  Lines       28593    28648      +55
==========================================
+ Hits        27008    27046      +38
- Misses       1585     1602      +17
```
```python
# --- Step 1: Topological sort via iterative Depth-First Search ---
# We need to process quantities with no dependencies before the
# quantities that depend on them. This way, when we process key K,
# all of K's dependencies already have their full_dependencies and
# full_with_axis_dependencies cached, and we can build K's full
# dependency set with a simple set union instead of deep recursion.
```
This sounds like it's basically doing what `set_tier` is doing below, so ideally we could do them both at the same time.
Yeah, I was actually going to suggest that in the looped compute PR. The order obtained at the end can be used directly; that is why I have a special sort instead of a normal sort.
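For what it's worth, merging the two could be as simple as deriving each key's tier in the same pass, since a topological order guarantees every dependency's tier is already known. This is a hypothetical sketch, not DESC's actual `set_tier` code:

```python
# Hypothetical sketch: given keys in topological order and their direct
# dependencies, a tier is one more than the max tier of the direct deps.
# No extra traversal is needed beyond the order itself.
def tiers_from_order(order, direct_deps):
    tier = {}
    for key in order:  # order guarantees deps already have a tier
        deps = direct_deps[key]
        tier[key] = 0 if not deps else 1 + max(tier[d] for d in deps)
    return tier

# e.g. R and Z at tier 0, "B" (needs R, Z) at tier 1, "modB" at tier 2:
tier = tiers_from_order(
    ["R", "Z", "B", "modB"],
    {"R": [], "Z": [], "B": ["R", "Z"], "modB": ["B"]},
)
assert tier == {"R": 0, "Z": 0, "B": 1, "modB": 2}
```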
Importing `desc.compute` takes a very long time; on my system, `_build_data_index` itself takes 3.5 seconds. This change makes it run in 0.15 seconds!

The function now first gets a topological order for all the keys, such that the ones with no dependencies (like R, Z, lambda) are at the beginning, then the ones that need only those, then tier 2, 3, etc. Once you have such an order, computing all the dependencies for a key is just the union of `data_index[p][k]["dependencies"]["data"]` and their own dependencies (which have already been computed thanks to this special order). No more recursion is needed.

The previous function was slow because of recursion, but it also made an extreme number of redundant calls. For example, `get_params`, `get_profiles`, and `get_derivs` call `get_deps` again internally, and since the result of the first call is not stored yet, each rebuilds the whole tree again, almost 8 more times.

This code was originally written by Claude Code (Opus 4.6), but I wrote most of the comments and I understand the algorithm behind it.
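A minimal sketch of the scheme described above, with hypothetical names rather than the actual `_build_data_index` code: an iterative DFS yields a post-order (a topological order in which every key appears after its dependencies), and full dependency sets are then built bottom-up with set unions, so there is no recursion and no redundant re-walk of the tree.

```python
# Illustrative sketch (not DESC's actual code) of iterative-DFS
# topological sorting plus bottom-up dependency-set union.
def full_dependencies(direct_deps):
    """direct_deps maps key -> list of direct dependency keys (acyclic)."""
    order, seen = [], set()
    for root in direct_deps:
        if root in seen:
            continue
        seen.add(root)
        stack = [(root, iter(direct_deps[root]))]
        while stack:
            node, children = stack[-1]
            for child in children:
                if child not in seen:
                    seen.add(child)
                    stack.append((child, iter(direct_deps[child])))
                    break  # descend; this iterator resumes later
            else:
                # All children visited: node's deps precede it in order.
                order.append(node)
                stack.pop()
    full = {}
    for key in order:  # deps of key are guaranteed to be in full already
        deps = set(direct_deps[key])
        for d in direct_deps[key]:
            deps |= full[d]  # simple set union, no recursion
        full[key] = deps
    return full

# Tiny example: R and Z have no deps, "B" needs both, "modB" needs B.
deps = full_dependencies({"R": [], "Z": [], "B": ["R", "Z"], "modB": ["B"]})
assert deps["modB"] == {"B", "R", "Z"}
```

Each key's full dependency set is computed exactly once, which is where the redundancy of the recursive version disappears.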
Also removes the `default_quad` computation at import time.

Resolves #2154