Skip to content

Minor changes from MPI branch#2153

Open
YigitElma wants to merge 15 commits intomasterfrom
yge/obj_jit
Open

Minor changes from MPI branch#2153
YigitElma wants to merge 15 commits intomasterfrom
yge/obj_jit

Conversation

@YigitElma
Copy link
Copy Markdown
Collaborator

@YigitElma YigitElma commented Apr 9, 2026

For some reason, some of these changes made the LinearConstraintProjection.build() faster on #1495
This PR is to find out why exactly, and also move some of the cosmetic changes from that branch to here.

  • Adds a helper function for all different compute methods (similar to jvp, jac wrappers we have)
  • Adds use_jit_wrapper logic, in case any of the sub-objectives shouldn't be compiled
  • Moves constants assignment to proper place (this was causing a bug in MPI PR, don't remember why exactly though)
  • Adds jit decorator to ObjectiveFunction.unpack_params() (this makes the difference in benchmarks)

@YigitElma YigitElma added run_benchmarks Run timing benchmarks on this PR against current master branch easy Short and simple to code or review labels Apr 9, 2026
@github-actions
Copy link
Copy Markdown
Contributor

github-actions bot commented Apr 9, 2026

Memory benchmark result

|               Test Name                |      %Δ      |    Master (MB)     |      PR (MB)       |    Δ (MB)    |    Time PR (s)     |  Time Master (s)   |
| -------------------------------------- | ------------ | ------------------ | ------------------ | ------------ | ------------------ | ------------------ |
  test_objective_jac_w7x                 |   -5.81 %    |     4.172e+03      |     3.929e+03      |   -242.20    |       40.87        |       39.76        |
  test_proximal_jac_w7x_with_eq_update   |   -1.28 %    |     6.584e+03      |     6.500e+03      |    -84.39    |       166.97       |       162.83       |
  test_proximal_freeb_jac                |   -0.24 %    |     1.344e+04      |     1.341e+04      |    -31.68    |       89.17        |       88.56        |
  test_proximal_freeb_jac_blocked        |    0.27 %    |     7.749e+03      |     7.770e+03      |    21.02     |       79.41        |       79.52        |
  test_proximal_freeb_jac_batched        |    0.45 %    |     7.675e+03      |     7.710e+03      |    34.56     |       79.53        |       78.72        |
+ test_proximal_jac_ripple               |   -11.77 %   |     3.623e+03      |     3.196e+03      |   -426.34    |       64.40        |       64.09        |
  test_proximal_jac_ripple_bounce1d      |   -7.53 %    |     3.806e+03      |     3.519e+03      |   -286.64    |       79.07        |       78.86        |
  test_eq_solve                          |   -1.87 %    |     2.181e+03      |     2.141e+03      |    -40.84    |       97.69        |       99.78        |

For the memory plots, go to the summary of Memory Benchmarks workflow and download the artifact.

@github-actions
Copy link
Copy Markdown
Contributor

github-actions bot commented Apr 9, 2026

|             benchmark_name             |         dt(%)          |         dt(s)          |        t_new(s)        |        t_old(s)        | 
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
 test_build_transform_fft_lowres         |     +0.67 +/- 2.55     | +4.32e-03 +/- 1.65e-02 |  6.51e-01 +/- 1.4e-02  |  6.46e-01 +/- 8.9e-03  |
 test_equilibrium_init_medres            |     +0.32 +/- 4.00     | +1.71e-02 +/- 2.14e-01 |  5.37e+00 +/- 1.8e-01  |  5.35e+00 +/- 1.2e-01  |
 test_equilibrium_init_highres           |     +0.34 +/- 4.16     | +2.05e-02 +/- 2.51e-01 |  6.04e+00 +/- 2.0e-01  |  6.02e+00 +/- 1.5e-01  |
 test_objective_compile_dshape_current   |     -0.62 +/- 1.37     | -1.99e-02 +/- 4.40e-02 |  3.19e+00 +/- 1.7e-02  |  3.21e+00 +/- 4.0e-02  |
 test_objective_compute_dshape_current   |     -6.42 +/- 6.19     | -3.43e-05 +/- 3.31e-05 |  5.01e-04 +/- 2.5e-05  |  5.35e-04 +/- 2.2e-05  |
 test_objective_jac_dshape_current       |     -7.00 +/- 29.39    | -1.62e-03 +/- 6.80e-03 |  2.15e-02 +/- 5.1e-03  |  2.31e-02 +/- 4.5e-03  |
 test_perturb_2                          |     -1.45 +/- 1.56     | -2.23e-01 +/- 2.39e-01 |  1.51e+01 +/- 1.1e-01  |  1.54e+01 +/- 2.1e-01  |
 test_proximal_jac_atf_with_eq_update    |     +2.68 +/- 0.98     | +2.64e-01 +/- 9.66e-02 |  1.01e+01 +/- 7.4e-02  |  9.84e+00 +/- 6.2e-02  |
 test_proximal_freeb_jac                 |     +0.43 +/- 1.09     | +1.65e-02 +/- 4.20e-02 |  3.85e+00 +/- 2.7e-02  |  3.83e+00 +/- 3.2e-02  |
 test_solve_fixed_iter_compiled          |     -2.67 +/- 1.55     | -1.81e-01 +/- 1.05e-01 |  6.61e+00 +/- 7.6e-02  |  6.79e+00 +/- 7.3e-02  |
+test_LinearConstraintProjection_build   |    -20.17 +/- 4.12     | -1.41e+00 +/- 2.87e-01 |  5.56e+00 +/- 1.5e-01  |  6.97e+00 +/- 2.4e-01  |
 test_objective_compute_ripple_bounce1d  |     -1.32 +/- 7.23     | -3.00e-03 +/- 1.64e-02 |  2.24e-01 +/- 1.2e-02  |  2.27e-01 +/- 1.2e-02  |
 test_objective_grad_ripple_bounce1d     |     +2.46 +/- 5.02     | +1.90e-02 +/- 3.88e-02 |  7.92e-01 +/- 3.3e-02  |  7.73e-01 +/- 2.0e-02  |
 test_build_transform_fft_midres         |     -2.34 +/- 2.66     | -2.07e-02 +/- 2.36e-02 |  8.65e-01 +/- 1.1e-02  |  8.86e-01 +/- 2.1e-02  |
 test_build_transform_fft_highres        |     -0.66 +/- 4.05     | -7.67e-03 +/- 4.69e-02 |  1.15e+00 +/- 1.7e-02  |  1.16e+00 +/- 4.4e-02  |
 test_equilibrium_init_lowres            |     -3.54 +/- 5.02     | -2.37e-01 +/- 3.36e-01 |  6.46e+00 +/- 1.7e-01  |  6.69e+00 +/- 2.9e-01  |
 test_objective_compile_atf              |     -1.69 +/- 3.76     | -1.05e-01 +/- 2.33e-01 |  6.10e+00 +/- 1.7e-01  |  6.21e+00 +/- 1.6e-01  |
 test_objective_compute_atf              |     -9.00 +/- 12.23    | -1.91e-04 +/- 2.60e-04 |  1.93e-03 +/- 2.4e-04  |  2.12e-03 +/- 1.1e-04  |
 test_objective_jac_atf                  |     +0.55 +/- 7.36     | +9.30e-03 +/- 1.24e-01 |  1.70e+00 +/- 8.6e-02  |  1.69e+00 +/- 8.9e-02  |
 test_perturb_1                          |     -0.01 +/- 2.10     | -1.40e-03 +/- 3.27e-01 |  1.56e+01 +/- 2.3e-01  |  1.56e+01 +/- 2.3e-01  |
 test_proximal_jac_atf                   |     +0.56 +/- 2.11     | +3.06e-02 +/- 1.16e-01 |  5.52e+00 +/- 6.8e-02  |  5.49e+00 +/- 9.3e-02  |
 test_proximal_freeb_compute             |     -2.30 +/- 3.19     | -3.23e-03 +/- 4.48e-03 |  1.37e-01 +/- 2.1e-03  |  1.40e-01 +/- 3.9e-03  |
 test_solve_fixed_iter                   |     -5.30 +/- 1.98     | -1.57e+00 +/- 5.87e-01 |  2.81e+01 +/- 4.0e-01  |  2.97e+01 +/- 4.2e-01  |
 test_objective_compute_ripple           |     -0.93 +/- 3.80     | -1.94e-03 +/- 7.90e-03 |  2.06e-01 +/- 4.8e-03  |  2.08e-01 +/- 6.3e-03  |
 test_objective_grad_ripple              |     +0.75 +/- 2.73     | +6.82e-03 +/- 2.49e-02 |  9.20e-01 +/- 2.4e-02  |  9.13e-01 +/- 7.6e-03  |

Github CI performance can be noisy. When evaluating the benchmarks, developers should take this into account.

@codecov
Copy link
Copy Markdown

codecov bot commented Apr 9, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 94.46%. Comparing base (03637dd) to head (202724c).

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #2153      +/-   ##
==========================================
+ Coverage   94.45%   94.46%   +0.01%     
==========================================
  Files         101      101              
  Lines       28604    28614      +10     
==========================================
+ Hits        27018    27031      +13     
+ Misses       1586     1583       -3     
Files with missing lines Coverage Δ
desc/objectives/objective_funs.py 94.98% <100.00%> (+0.13%) ⬆️
desc/optimize/_constraint_wrappers.py 97.41% <100.00%> (+0.06%) ⬆️

... and 2 files with indirect coverage changes

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@YigitElma
Copy link
Copy Markdown
Collaborator Author

use_jit_wrapper allows the following:

eq = get("precise_QA")

class DummyObj(_Objective):
    def __init__(self, eq, grid=None):
        if grid is None:
            grid = LinearGrid(M=eq.M, N=eq.N, NFP=eq.NFP)
        self._grid = grid
        super().__init__(things=eq)

    def build(self, use_jit, verbose=3):
        print("build called")
        self._dim_f = 1
        self.transforms = get_transforms(["|B|"], self.things[0], self._grid)
        self.profiles = get_profiles(["|B|"], self.things[0], self._grid)
        super().build(use_jit=use_jit, verbose=verbose)

    def compute(self, params, constants=None):
        print("fun called")
        data = compute_fun(
            self.things[0],
            ["|B|"],
            params=params,
            transforms=self.transforms,
            profiles=self.profiles,
        )
        f = data["|B|"].max()
        if f > 0:
            f += data["|B|"].min()
            print(f"{f}")
        return f


obj1 = DummyObj(eq)
obj1.build(use_jit=False)
obj = ObjectiveFunction((ForceBalance(eq), obj1))
obj.build(use_jit=True)
obj.compute_scalar(obj.x(eq))
obj.grad(obj.x(eq))
build called
Building objective: force
Precomputing transforms
fun called
2.0040131258026577
fun called
LinearizeTracer<float64[]>
...

Comment thread desc/objectives/objective_funs.py
@YigitElma
Copy link
Copy Markdown
Collaborator Author

Might need to unjit helpers of ProximalProjection to work properly use_jit_wrapper=False case

@YigitElma YigitElma self-assigned this Apr 9, 2026
@YigitElma
Copy link
Copy Markdown
Collaborator Author

Add a warning for use_jit_wrapper

Comment thread desc/objectives/objective_funs.py
@YigitElma YigitElma marked this pull request as ready for review April 10, 2026 01:14
@YigitElma YigitElma marked this pull request as draft April 10, 2026 03:10
@YigitElma YigitElma marked this pull request as ready for review April 10, 2026 04:55
@YigitElma YigitElma requested review from ddudt, dpanici and f0uriest April 12, 2026 21:05
Copy link
Copy Markdown
Collaborator

@ddudt ddudt left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor comments/suggestions, but otherwise this looks good.

Comment thread CHANGELOG.md Outdated
Comment thread desc/objectives/objective_funs.py Outdated
Comment on lines +4331 to +4333
assert f == 2 * r**2 if r > 0 else 9 * r**2 / 2
assert g[0] == 4 * r if r > 0 else 9 * r
assert np.all(g[1:] == 0)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this just checking that the objective values and derivatives are correct when it is not JITed? Doesn't feel necessary but doesn't hurt.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, just a sanity check that jittability is not necessary for AD

YigitElma and others added 2 commits April 13, 2026 13:05
Co-authored-by: Daniel Dudt <33005725+ddudt@users.noreply.github.com>
@YigitElma YigitElma requested a review from ddudt April 14, 2026 04:00
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

easy Short and simple to code or review run_benchmarks Run timing benchmarks on this PR against current master branch

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants