@@ -202,45 +202,43 @@ def compute_jacvec_prod(self, inputs, d_inputs, d_outputs, mode):
202202 dirichlet_nodes , dirichlet_values )
203203
204204 if mode == "fwd" :
205- # Build tangent (directional) inputs only for the active variables.
206- tan_density = jnp .array (d_inputs ["density" ]) if "density" in d_inputs and d_inputs [
207- "density" ] is not None else jnp .zeros_like (density )
208- tan_heat_loads = jnp .array (d_inputs ["heat_loads" ]) if "heat_loads" in d_inputs and d_inputs [
209- "heat_loads" ] is not None else jnp .zeros_like (heat_loads )
210- # For all other parameters, the directional derivative is zero.
205+
206+ # Build tangent
207+ tan_density = jnp .array (d_inputs ["density" ]) if "density" in d_inputs and d_inputs ["density" ] is not None else jnp .zeros_like (density )
208+ tan_heat_loads = jnp .array (d_inputs ["heat_loads" ]) if "heat_loads" in d_inputs and d_inputs ["heat_loads" ] is not None else jnp .zeros_like (heat_loads )
211209 tan_nodes = jnp .zeros_like (nodes )
212210 tan_elements = jnp .zeros_like (elements )
213211 tan_robin_nodes = jnp .zeros_like (robin_nodes )
214- tan_robin_h = 0.0
215- tan_robin_T_inf = 0.0
216- tan_robin_area = 0.0 # since robin_area is computed from el_size
212+ tan_robin_h = jnp . zeros_like ( robin_h )
213+ tan_robin_T_inf = jnp . zeros_like ( robin_T_inf )
214+ tan_robin_area = jnp . zeros_like ( robin_area )
217215 tan_dirichlet_nodes = jnp .zeros_like (dirichlet_nodes )
218216 tan_dirichlet_values = jnp .zeros_like (dirichlet_values )
217+
219218 tangents = (tan_density , tan_heat_loads , tan_nodes , tan_elements ,
220219 tan_robin_nodes , tan_robin_h , tan_robin_T_inf , tan_robin_area ,
221220 tan_dirichlet_nodes , tan_dirichlet_values )
222- # Compute forward-mode JVP. Since _compute_primal returns a tuple (u, u_max),
223- # tangent_out will be a tuple with the corresponding directional derivatives.
221+
222+ # Compute forward-mode JVP
224223 _ , tangent_out = jvp (self ._compute_primal , primals , tangents )
224+
225225 d_outputs ["temperature" ] = tangent_out [0 ]
226226 d_outputs ["max_temperature" ] = tangent_out [1 ]
227227
228228 elif mode == "rev" :
229- # Compute the reverse-mode VJP.
230- # _compute_primal returns a tuple (u, u_max)
231- primal_out , pullback = vjp (self ._compute_primal , * primals )
232- cotan_temperature = d_outputs ["temperature" ] if "temperature" in d_outputs and d_outputs [
233- "temperature" ] is not None else jnp .zeros_like (primal_out [0 ])
234- cotan_max_temperature = d_outputs ["max_temperature" ] if "max_temperature" in d_outputs and d_outputs [
235- "max_temperature" ] is not None else 0.0
236- # The pullback now expects a tuple of cotangents.
237- grads = pullback ((cotan_temperature , cotan_max_temperature ))
238- # grads is a tuple with gradients for each input argument;
239- # we only need to accumulate contributions for our optimization variables:
240- # d_inputs["density"] = grads[0]
241- # d_inputs["heat_loads"] = grads[1]
242- d_inputs ["density" ] = d_inputs .get ("density" , 0 ) + grads [0 ]
243- d_inputs ["heat_loads" ] = d_inputs .get ("heat_loads" , 0 ) + grads [1 ]
229+
230+ _ , pullback = vjp (self ._compute_primal , * primals )
231+
232+ cotangent = (d_outputs ["temperature" ],
233+ d_outputs ["max_temperature" ])
234+
235+ grads = pullback (cotangent )
236+
237+ # Only assign the non-None gradients
238+ # 0 = density, 1 = heat_loads
239+ # 2, 3, ... are constants
240+ d_inputs ["density" ] += grads [0 ]
241+ d_inputs ["heat_loads" ] += grads [1 ]
244242
245243
246244class BoundaryConditionAggregator (ExplicitComponent ):
0 commit comments