Skip to content

Commit 49873a5

Browse files
Merge pull request #930 from SciML/sb/docs_fix
docs: fix pino tutorial
2 parents 2dd77fa + 131e257 commit 49873a5

File tree

1 file changed

+15
-5
lines changed

1 file changed

+15
-5
lines changed

docs/src/tutorials/pino_ode.md

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,22 +21,28 @@ u0 = 1.0
2121
prob = ODEProblem(equation, u0, tspan)
2222
2323
# Set the number of parameters for the ODE
24-
number_of_parameter = 3
24+
num_params = 3
25+
2526
# Define the DeepONet architecture for the PINO
2627
deeponet = NeuralOperators.DeepONet(
2728
Chain(
28-
Dense(number_of_parameter => 10, Lux.tanh_fast), Dense(10 => 10, Lux.tanh_fast), Dense(10 => 10)),
29+
Dense(num_params => 10, Lux.tanh_fast), Dense(10 => 10, Lux.tanh_fast), Dense(10 => 10)),
2930
Chain(Dense(1 => 10, Lux.tanh_fast), Dense(10 => 10, Lux.tanh_fast),
3031
Dense(10 => 10, Lux.tanh_fast)))
3132
3233
# Define the bounds for the parameters
3334
bounds = [(1.0, pi), (1.0, 2.0), (2.0, 3.0)]
3435
number_of_parameter_samples = 50
36+
3537
# Define the training strategy
3638
strategy = StochasticTraining(20)
39+
3740
# Define the optimizer
3841
opt = OptimizationOptimisers.Adam(0.03)
39-
alg = PINOODE(deeponet, opt, bounds, number_of_parameters; strategy = strategy)
42+
43+
# Define `PINNODE`
44+
alg = PINOODE(deeponet, opt, bounds, num_params; strategy = strategy)
45+
4046
# Solve the ODE problem using the PINOODE algorithm
4147
sol = solve(prob, alg, verbose = false, maxiters = 4000)
4248
```
@@ -63,21 +69,25 @@ end
6369
6470
# generate the solution with new parameters for test the model
6571
(p, t) = get_trainset(bounds, tspan, 50, 0.025)
72+
6673
# compute the ground truth solution
6774
ground_solution_ = ground_solution_f(p, t)
75+
6876
# predict the solution with the PINO model
69-
predict = sol.interp((p, t))
77+
predict = sol.interp(p, t)
7078
7179
# calculate the errors between the ground truth solution and the predicted solution
7280
errors = ground_solution_ - predict
81+
7382
# calculate the mean error and the standard deviation of the errors
7483
mean_error = mean(errors)
84+
7585
# calculate the standard deviation of the errors
7686
std_error = std(errors)
7787
7888
p, t = get_trainset(bounds, tspan, 100, 0.01)
7989
ground_solution_ = ground_solution_f(p, t)
80-
predict = sol.interp((p, t))
90+
predict = sol.interp(p, t)
8191
8292
errors = ground_solution_ - predict
8393
mean_error = mean(errors)

0 commit comments

Comments
 (0)