Skip to content

Commit 9147965

Browse files
committed
Fixing projection
1 parent 439a73f commit 9147965

File tree

2 files changed

+6
-5
lines changed

2 files changed

+6
-5
lines changed

examples/MVP/workflow.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -476,14 +476,14 @@
476476
total_time = (t3 - t1) / 1e9
477477
print(f"Total time: {total_time} seconds")
478478

479-
# pf = prob.check_partials(includes='FEA')
479+
pf = prob.check_partials(includes='FEA')
480480
# tot = prob.compute_totals(of=['FEA.max_temperature'], wrt=['system.components.comp_2.translation'])
481481
# tot = tot[('FEA.max_temperature', 'system.components.comp_2.translation')][0]
482482
# print("total", tot)
483-
d_inputs = {'density': jnp.ones_like(densities_combined), 'heat_loads': jnp.ones_like(densities_combined)}
484-
d_outputs = {'temperature': jnp.ones_like(T), 'max_temperature': jnp.array(1)}
483+
# d_inputs = {'density': jnp.ones_like(densities_combined), 'heat_loads': jnp.ones_like(densities_combined)}
484+
# d_outputs = {'temperature': jnp.ones_like(T), 'max_temperature': jnp.array(1)}
485485

486-
jvp_vals = prob.model.FEA.compute_jacvec_prod(prob.model.FEA._inputs, d_inputs, d_outputs, mode='fwd')
486+
# jvp_vals = prob.model.FEA.compute_jacvec_prod(prob.model.FEA._inputs, d_inputs, d_outputs, mode='fwd')
487487
# vjp_vals = prob.model.FEA.compute_jacvec_prod(prob.model.FEA._inputs, d_inputs, d_outputs, mode='rev')
488488

489489

src/SPI2py/API/FEA.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ def fea_solve(rhs):
176176

177177
return u, u_max
178178

179-
def compute_jacvec_prod(self, inputs, d_inputs, d_outputs, mode):
179+
def compute_jacvec_product(self, inputs, d_inputs, d_outputs, mode):
180180

181181
# Unpack dynamic (differentiable) inputs.
182182
density = jnp.array(inputs["density"])
@@ -206,6 +206,7 @@ def compute_jacvec_prod(self, inputs, d_inputs, d_outputs, mode):
206206
d_T=dirichlet_values
207207
)
208208

209+
209210
if mode == "fwd":
210211
# Build tangents only for differentiable inputs.
211212
tan_density = (jnp.array(d_inputs["density"])

0 commit comments

Comments
 (0)