Skip to content

Commit 8975c39

Browse files
committed
Fixing projection
1 parent 41d1157 commit 8975c39

File tree

2 files changed

+25
-27
lines changed

2 files changed

+25
-27
lines changed

examples/MVP/workflow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -476,6 +476,6 @@
476476
total_time = (t3 - t1) / 1e9
477477
print(f"Total time: {total_time} seconds")
478478

479-
pf = prob.check_partials(includes='projections.proj_1')
479+
pf = prob.check_partials(includes='FEA')
480480

481481
print('Done')

src/SPI2py/API/FEA.py

Lines changed: 24 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -202,45 +202,43 @@ def compute_jacvec_prod(self, inputs, d_inputs, d_outputs, mode):
202202
dirichlet_nodes, dirichlet_values)
203203

204204
if mode == "fwd":
205-
# Build tangent (directional) inputs only for the active variables.
206-
tan_density = jnp.array(d_inputs["density"]) if "density" in d_inputs and d_inputs[
207-
"density"] is not None else jnp.zeros_like(density)
208-
tan_heat_loads = jnp.array(d_inputs["heat_loads"]) if "heat_loads" in d_inputs and d_inputs[
209-
"heat_loads"] is not None else jnp.zeros_like(heat_loads)
210-
# For all other parameters, the directional derivative is zero.
205+
206+
# Build tangent
207+
tan_density = jnp.array(d_inputs["density"]) if "density" in d_inputs and d_inputs["density"] is not None else jnp.zeros_like(density)
208+
tan_heat_loads = jnp.array(d_inputs["heat_loads"]) if "heat_loads" in d_inputs and d_inputs["heat_loads"] is not None else jnp.zeros_like(heat_loads)
211209
tan_nodes = jnp.zeros_like(nodes)
212210
tan_elements = jnp.zeros_like(elements)
213211
tan_robin_nodes = jnp.zeros_like(robin_nodes)
214-
tan_robin_h = 0.0
215-
tan_robin_T_inf = 0.0
216-
tan_robin_area = 0.0 # since robin_area is computed from el_size
212+
tan_robin_h = jnp.zeros_like(robin_h)
213+
tan_robin_T_inf = jnp.zeros_like(robin_T_inf)
214+
tan_robin_area = jnp.zeros_like(robin_area)
217215
tan_dirichlet_nodes = jnp.zeros_like(dirichlet_nodes)
218216
tan_dirichlet_values = jnp.zeros_like(dirichlet_values)
217+
219218
tangents = (tan_density, tan_heat_loads, tan_nodes, tan_elements,
220219
tan_robin_nodes, tan_robin_h, tan_robin_T_inf, tan_robin_area,
221220
tan_dirichlet_nodes, tan_dirichlet_values)
222-
# Compute forward-mode JVP. Since _compute_primal returns a tuple (u, u_max),
223-
# tangent_out will be a tuple with the corresponding directional derivatives.
221+
222+
# Compute forward-mode JVP
224223
_, tangent_out = jvp(self._compute_primal, primals, tangents)
224+
225225
d_outputs["temperature"] = tangent_out[0]
226226
d_outputs["max_temperature"] = tangent_out[1]
227227

228228
elif mode == "rev":
229-
# Compute the reverse-mode VJP.
230-
# _compute_primal returns a tuple (u, u_max)
231-
primal_out, pullback = vjp(self._compute_primal, *primals)
232-
cotan_temperature = d_outputs["temperature"] if "temperature" in d_outputs and d_outputs[
233-
"temperature"] is not None else jnp.zeros_like(primal_out[0])
234-
cotan_max_temperature = d_outputs["max_temperature"] if "max_temperature" in d_outputs and d_outputs[
235-
"max_temperature"] is not None else 0.0
236-
# The pullback now expects a tuple of cotangents.
237-
grads = pullback((cotan_temperature, cotan_max_temperature))
238-
# grads is a tuple with gradients for each input argument;
239-
# we only need to accumulate contributions for our optimization variables:
240-
# d_inputs["density"] = grads[0]
241-
# d_inputs["heat_loads"] = grads[1]
242-
d_inputs["density"] = d_inputs.get("density", 0) + grads[0]
243-
d_inputs["heat_loads"] = d_inputs.get("heat_loads", 0) + grads[1]
229+
230+
_, pullback = vjp(self._compute_primal, *primals)
231+
232+
cotangent = (d_outputs["temperature"],
233+
d_outputs["max_temperature"])
234+
235+
grads = pullback(cotangent)
236+
237+
# Only assign the non-None gradients
238+
# 0 = density, 1 = heat_loads
239+
# 2, 3, ... are constants
240+
d_inputs["density"] += grads[0]
241+
d_inputs["heat_loads"] += grads[1]
244242

245243

246244
class BoundaryConditionAggregator(ExplicitComponent):

0 commit comments

Comments
 (0)