Skip to content

Commit 439a73f

Browse files
committed
Fixing projection
1 parent 8975c39 commit 439a73f

File tree

2 files changed

+65
-60
lines changed

2 files changed

+65
-60
lines changed

examples/MVP/workflow.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -476,6 +476,15 @@
476476
total_time = (t3 - t1) / 1e9
477477
print(f"Total time: {total_time} seconds")
478478

479-
pf = prob.check_partials(includes='FEA')
479+
# pf = prob.check_partials(includes='FEA')
480+
# tot = prob.compute_totals(of=['FEA.max_temperature'], wrt=['system.components.comp_2.translation'])
481+
# tot = tot[('FEA.max_temperature', 'system.components.comp_2.translation')][0]
482+
# print("total", tot)
483+
d_inputs = {'density': jnp.ones_like(densities_combined), 'heat_loads': jnp.ones_like(densities_combined)}
484+
d_outputs = {'temperature': jnp.ones_like(T), 'max_temperature': jnp.array(1)}
485+
486+
jvp_vals = prob.model.FEA.compute_jacvec_prod(prob.model.FEA._inputs, d_inputs, d_outputs, mode='fwd')
487+
# vjp_vals = prob.model.FEA.compute_jacvec_prod(prob.model.FEA._inputs, d_inputs, d_outputs, mode='rev')
488+
480489

481490
print('Done')

src/SPI2py/API/FEA.py

Lines changed: 55 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from functools import partial
12
import jax.numpy as jnp
23
import jax
34
from jax import jvp, vjp
@@ -81,17 +82,17 @@ def setup_partials(self):
8182
def compute(self, inputs, outputs):
8283

8384
# Unpack the options
84-
nodes = self.options['nodes']
85-
elements = self.options['elements']
86-
el_size = self.options['el_size']
87-
el_centers = self.options['el_centers']
88-
robin_nodes = self.options['robin_nodes']
89-
robin_h = self.options['robin_h']
90-
robin_T_inf = self.options['robin_T_inf']
91-
dirichlet_nodes = self.options['dirichlet_nodes']
92-
dirichlet_values = self.options['dirichlet_values']
93-
94-
robin_area = el_size ** 2
85+
nodes = jnp.array(self.options['nodes'])
86+
elements = jnp.array(self.options['elements'])
87+
el_size = jnp.array(self.options['el_size'])
88+
el_centers = jnp.array(self.options['el_centers'])
89+
robin_nodes = jnp.array(self.options['robin_nodes'])
90+
robin_h = jnp.array(self.options['robin_h'])
91+
robin_T_inf = jnp.array(self.options['robin_T_inf'])
92+
dirichlet_nodes = jnp.array(self.options['dirichlet_nodes'])
93+
dirichlet_values = jnp.array(self.options['dirichlet_values'])
94+
95+
robin_area = jnp.array(el_size ** 2)
9596

9697
# Unpack the inputs
9798
density = jnp.array(inputs["density"])
@@ -176,67 +177,62 @@ def fea_solve(rhs):
176177
return u, u_max
177178

178179
def compute_jacvec_prod(self, inputs, d_inputs, d_outputs, mode):
179-
"""
180-
Compute the matrix-free Jacobian-vector product for both outputs:
181-
- "temperature" (full field) and
182-
- "max_temperature" (a scalar, e.g., from a Kreisselmeier–Steinhauser function).
183-
"""
184-
# Get current input values.
180+
181+
# Unpack dynamic (differentiable) inputs.
185182
density = jnp.array(inputs["density"])
186183
heat_loads = jnp.array(inputs["heat_loads"])
187184

188-
# Unpack options.
189-
nodes = self.options['nodes']
190-
elements = self.options['elements']
191-
el_size = self.options['el_size']
192-
robin_nodes = self.options['robin_nodes']
193-
robin_h = self.options['robin_h']
194-
robin_T_inf = self.options['robin_T_inf']
195-
dirichlet_nodes = self.options['dirichlet_nodes']
196-
dirichlet_values = self.options['dirichlet_values']
197-
robin_area = el_size ** 2
198-
199-
# Prepare the full set of "primal" arguments.
200-
primals = (density, heat_loads, nodes, elements,
201-
robin_nodes, robin_h, robin_T_inf, robin_area,
202-
dirichlet_nodes, dirichlet_values)
185+
# Unpack static options.
186+
nodes = jnp.array(self.options['nodes'])
187+
elements = jnp.array(self.options['elements'])
188+
el_size = jnp.array(self.options['el_size'])
189+
robin_nodes = jnp.array(self.options['robin_nodes'])
190+
robin_h = jnp.array(self.options['robin_h'])
191+
robin_T_inf = jnp.array(self.options['robin_T_inf'])
192+
dirichlet_nodes = jnp.array(self.options['dirichlet_nodes'])
193+
dirichlet_values = jnp.array(self.options['dirichlet_values'])
194+
robin_area = jnp.array(el_size ** 2)
195+
196+
# Freeze all static arguments via partial so that only density and heat_loads are inputs.
197+
frozen_compute_primal = partial(
198+
self._compute_primal,
199+
nodes=nodes,
200+
elements=elements,
201+
r_nodes=robin_nodes,
202+
r_h=robin_h,
203+
r_T_inf=robin_T_inf,
204+
r_area=robin_area,
205+
d_nodes=dirichlet_nodes,
206+
d_T=dirichlet_values
207+
)
203208

204209
if mode == "fwd":
210+
# Build tangents only for differentiable inputs.
211+
tan_density = (jnp.array(d_inputs["density"])
212+
if "density" in d_inputs and d_inputs["density"] is not None
213+
else jnp.zeros_like(density))
205214

206-
# Build tangent
207-
tan_density = jnp.array(d_inputs["density"]) if "density" in d_inputs and d_inputs["density"] is not None else jnp.zeros_like(density)
208-
tan_heat_loads = jnp.array(d_inputs["heat_loads"]) if "heat_loads" in d_inputs and d_inputs["heat_loads"] is not None else jnp.zeros_like(heat_loads)
209-
tan_nodes = jnp.zeros_like(nodes)
210-
tan_elements = jnp.zeros_like(elements)
211-
tan_robin_nodes = jnp.zeros_like(robin_nodes)
212-
tan_robin_h = jnp.zeros_like(robin_h)
213-
tan_robin_T_inf = jnp.zeros_like(robin_T_inf)
214-
tan_robin_area = jnp.zeros_like(robin_area)
215-
tan_dirichlet_nodes = jnp.zeros_like(dirichlet_nodes)
216-
tan_dirichlet_values = jnp.zeros_like(dirichlet_values)
215+
tan_heat_loads = (jnp.array(d_inputs["heat_loads"])
216+
if "heat_loads" in d_inputs and d_inputs["heat_loads"] is not None
217+
else jnp.zeros_like(heat_loads))
217218

218-
tangents = (tan_density, tan_heat_loads, tan_nodes, tan_elements,
219-
tan_robin_nodes, tan_robin_h, tan_robin_T_inf, tan_robin_area,
220-
tan_dirichlet_nodes, tan_dirichlet_values)
219+
primals = (density, heat_loads)
220+
tangents = (tan_density, tan_heat_loads)
221221

222-
# Compute forward-mode JVP
223-
_, tangent_out = jvp(self._compute_primal, primals, tangents)
222+
# Call jax.jvp on the frozen function.
223+
_, tangent_out = jvp(frozen_compute_primal, primals, tangents)
224224

225-
d_outputs["temperature"] = tangent_out[0]
226-
d_outputs["max_temperature"] = tangent_out[1]
225+
# Assign the computed derivatives to the outputs.
226+
d_outputs["temperature"] += tangent_out[0]
227+
d_outputs["max_temperature"] += tangent_out[1]
227228

228229
elif mode == "rev":
229-
230-
_, pullback = vjp(self._compute_primal, *primals)
231-
232-
cotangent = (d_outputs["temperature"],
233-
d_outputs["max_temperature"])
234-
230+
# Get the VJP (pullback) for the frozen function.
231+
_, pullback = vjp(frozen_compute_primal, density, heat_loads)
232+
cotangent = (d_outputs["temperature"], d_outputs["max_temperature"])
235233
grads = pullback(cotangent)
236234

237-
# Only assign the non-None gradients
238-
# 0 = density, 1 = heat_loads
239-
# 2, 3, ... are constants
235+
# Accumulate the gradients into d_inputs.
240236
d_inputs["density"] += grads[0]
241237
d_inputs["heat_loads"] += grads[1]
242238

0 commit comments

Comments
 (0)