|
| 1 | +from functools import partial |
1 | 2 | import jax.numpy as jnp |
2 | 3 | import jax |
3 | 4 | from jax import jvp, vjp |
@@ -81,17 +82,17 @@ def setup_partials(self): |
81 | 82 | def compute(self, inputs, outputs): |
82 | 83 |
|
83 | 84 | # Unpack the options |
84 | | - nodes = self.options['nodes'] |
85 | | - elements = self.options['elements'] |
86 | | - el_size = self.options['el_size'] |
87 | | - el_centers = self.options['el_centers'] |
88 | | - robin_nodes = self.options['robin_nodes'] |
89 | | - robin_h = self.options['robin_h'] |
90 | | - robin_T_inf = self.options['robin_T_inf'] |
91 | | - dirichlet_nodes = self.options['dirichlet_nodes'] |
92 | | - dirichlet_values = self.options['dirichlet_values'] |
93 | | - |
94 | | - robin_area = el_size ** 2 |
| 85 | + nodes = jnp.array(self.options['nodes']) |
| 86 | + elements = jnp.array(self.options['elements']) |
| 87 | + el_size = jnp.array(self.options['el_size']) |
| 88 | + el_centers = jnp.array(self.options['el_centers']) |
| 89 | + robin_nodes = jnp.array(self.options['robin_nodes']) |
| 90 | + robin_h = jnp.array(self.options['robin_h']) |
| 91 | + robin_T_inf = jnp.array(self.options['robin_T_inf']) |
| 92 | + dirichlet_nodes = jnp.array(self.options['dirichlet_nodes']) |
| 93 | + dirichlet_values = jnp.array(self.options['dirichlet_values']) |
| 94 | + |
| 95 | + robin_area = jnp.array(el_size ** 2) |
95 | 96 |
|
96 | 97 | # Unpack the inputs |
97 | 98 | density = jnp.array(inputs["density"]) |
@@ -176,67 +177,62 @@ def fea_solve(rhs): |
176 | 177 | return u, u_max |
177 | 178 |
|
178 | 179 | def compute_jacvec_prod(self, inputs, d_inputs, d_outputs, mode): |
179 | | - """ |
180 | | - Compute the matrix-free Jacobian-vector product for both outputs: |
181 | | - - "temperature" (full field) and |
182 | | - - "max_temperature" (a scalar, e.g., from a Kreisselmeier–Steinhauser function). |
183 | | - """ |
184 | | - # Get current input values. |
| 180 | + |
| 181 | + # Unpack dynamic (differentiable) inputs. |
185 | 182 | density = jnp.array(inputs["density"]) |
186 | 183 | heat_loads = jnp.array(inputs["heat_loads"]) |
187 | 184 |
|
188 | | - # Unpack options. |
189 | | - nodes = self.options['nodes'] |
190 | | - elements = self.options['elements'] |
191 | | - el_size = self.options['el_size'] |
192 | | - robin_nodes = self.options['robin_nodes'] |
193 | | - robin_h = self.options['robin_h'] |
194 | | - robin_T_inf = self.options['robin_T_inf'] |
195 | | - dirichlet_nodes = self.options['dirichlet_nodes'] |
196 | | - dirichlet_values = self.options['dirichlet_values'] |
197 | | - robin_area = el_size ** 2 |
198 | | - |
199 | | - # Prepare the full set of "primal" arguments. |
200 | | - primals = (density, heat_loads, nodes, elements, |
201 | | - robin_nodes, robin_h, robin_T_inf, robin_area, |
202 | | - dirichlet_nodes, dirichlet_values) |
| 185 | + # Unpack static options. |
| 186 | + nodes = jnp.array(self.options['nodes']) |
| 187 | + elements = jnp.array(self.options['elements']) |
| 188 | + el_size = jnp.array(self.options['el_size']) |
| 189 | + robin_nodes = jnp.array(self.options['robin_nodes']) |
| 190 | + robin_h = jnp.array(self.options['robin_h']) |
| 191 | + robin_T_inf = jnp.array(self.options['robin_T_inf']) |
| 192 | + dirichlet_nodes = jnp.array(self.options['dirichlet_nodes']) |
| 193 | + dirichlet_values = jnp.array(self.options['dirichlet_values']) |
| 194 | + robin_area = jnp.array(el_size ** 2) |
| 195 | + |
| 196 | + # Freeze all static arguments via partial so that only density and heat_loads are inputs. |
| 197 | + frozen_compute_primal = partial( |
| 198 | + self._compute_primal, |
| 199 | + nodes=nodes, |
| 200 | + elements=elements, |
| 201 | + r_nodes=robin_nodes, |
| 202 | + r_h=robin_h, |
| 203 | + r_T_inf=robin_T_inf, |
| 204 | + r_area=robin_area, |
| 205 | + d_nodes=dirichlet_nodes, |
| 206 | + d_T=dirichlet_values |
| 207 | + ) |
203 | 208 |
|
204 | 209 | if mode == "fwd": |
| 210 | + # Build tangents only for differentiable inputs. |
| 211 | + tan_density = (jnp.array(d_inputs["density"]) |
| 212 | + if "density" in d_inputs and d_inputs["density"] is not None |
| 213 | + else jnp.zeros_like(density)) |
205 | 214 |
|
206 | | - # Build tangent |
207 | | - tan_density = jnp.array(d_inputs["density"]) if "density" in d_inputs and d_inputs["density"] is not None else jnp.zeros_like(density) |
208 | | - tan_heat_loads = jnp.array(d_inputs["heat_loads"]) if "heat_loads" in d_inputs and d_inputs["heat_loads"] is not None else jnp.zeros_like(heat_loads) |
209 | | - tan_nodes = jnp.zeros_like(nodes) |
210 | | - tan_elements = jnp.zeros_like(elements) |
211 | | - tan_robin_nodes = jnp.zeros_like(robin_nodes) |
212 | | - tan_robin_h = jnp.zeros_like(robin_h) |
213 | | - tan_robin_T_inf = jnp.zeros_like(robin_T_inf) |
214 | | - tan_robin_area = jnp.zeros_like(robin_area) |
215 | | - tan_dirichlet_nodes = jnp.zeros_like(dirichlet_nodes) |
216 | | - tan_dirichlet_values = jnp.zeros_like(dirichlet_values) |
| 215 | + tan_heat_loads = (jnp.array(d_inputs["heat_loads"]) |
| 216 | + if "heat_loads" in d_inputs and d_inputs["heat_loads"] is not None |
| 217 | + else jnp.zeros_like(heat_loads)) |
217 | 218 |
|
218 | | - tangents = (tan_density, tan_heat_loads, tan_nodes, tan_elements, |
219 | | - tan_robin_nodes, tan_robin_h, tan_robin_T_inf, tan_robin_area, |
220 | | - tan_dirichlet_nodes, tan_dirichlet_values) |
| 219 | + primals = (density, heat_loads) |
| 220 | + tangents = (tan_density, tan_heat_loads) |
221 | 221 |
|
222 | | - # Compute forward-mode JVP |
223 | | - _, tangent_out = jvp(self._compute_primal, primals, tangents) |
| 222 | + # Call jax.jvp on the frozen function. |
| 223 | + _, tangent_out = jvp(frozen_compute_primal, primals, tangents) |
224 | 224 |
|
225 | | - d_outputs["temperature"] = tangent_out[0] |
226 | | - d_outputs["max_temperature"] = tangent_out[1] |
| 225 | + # Assign the computed derivatives to the outputs. |
| 226 | + d_outputs["temperature"] += tangent_out[0] |
| 227 | + d_outputs["max_temperature"] += tangent_out[1] |
227 | 228 |
|
228 | 229 | elif mode == "rev": |
229 | | - |
230 | | - _, pullback = vjp(self._compute_primal, *primals) |
231 | | - |
232 | | - cotangent = (d_outputs["temperature"], |
233 | | - d_outputs["max_temperature"]) |
234 | | - |
| 230 | + # Get the VJP (pullback) for the frozen function. |
| 231 | + _, pullback = vjp(frozen_compute_primal, density, heat_loads) |
| 232 | + cotangent = (d_outputs["temperature"], d_outputs["max_temperature"]) |
235 | 233 | grads = pullback(cotangent) |
236 | 234 |
|
237 | | - # Only assign the non-None gradients |
238 | | - # 0 = density, 1 = heat_loads |
239 | | - # 2, 3, ... are constants |
| 235 | + # Accumulate the gradients into d_inputs. |
240 | 236 | d_inputs["density"] += grads[0] |
241 | 237 | d_inputs["heat_loads"] += grads[1] |
242 | 238 |
|
|
0 commit comments