@@ -21,22 +21,28 @@ u0 = 1.0
2121prob = ODEProblem(equation, u0, tspan)
2222
2323# Set the number of parameters for the ODE
24- number_of_parameter = 3
24+ num_params = 3
25+
2526# Define the DeepONet architecture for the PINO
2627deeponet = NeuralOperators.DeepONet(
2728 Chain(
28- Dense(number_of_parameter => 10, Lux.tanh_fast), Dense(10 => 10, Lux.tanh_fast), Dense(10 => 10)),
29+ Dense(num_params => 10, Lux.tanh_fast), Dense(10 => 10, Lux.tanh_fast), Dense(10 => 10)),
2930 Chain(Dense(1 => 10, Lux.tanh_fast), Dense(10 => 10, Lux.tanh_fast),
3031 Dense(10 => 10, Lux.tanh_fast)))
3132
3233# Define the bounds for the parameters
3334bounds = [(1.0, pi), (1.0, 2.0), (2.0, 3.0)]
3435number_of_parameter_samples = 50
36+
3537# Define the training strategy
3638strategy = StochasticTraining(20)
39+
3740# Define the optimizer
3841opt = OptimizationOptimisers.Adam(0.03)
39- alg = PINOODE(deeponet, opt, bounds, number_of_parameters; strategy = strategy)
42+
43+ # Define `PINNODE`
44+ alg = PINOODE(deeponet, opt, bounds, num_params; strategy = strategy)
45+
4046# Solve the ODE problem using the PINOODE algorithm
4147sol = solve(prob, alg, verbose = false, maxiters = 4000)
4248```
6369
6470# generate the solution with new parameters for test the model
6571(p, t) = get_trainset(bounds, tspan, 50, 0.025)
72+
6673# compute the ground truth solution
6774ground_solution_ = ground_solution_f(p, t)
75+
6876# predict the solution with the PINO model
69- predict = sol.interp(( p, t) )
77+ predict = sol.interp(p, t)
7078
7179# calculate the errors between the ground truth solution and the predicted solution
7280errors = ground_solution_ - predict
81+
7382# calculate the mean error and the standard deviation of the errors
7483mean_error = mean(errors)
84+
7585# calculate the standard deviation of the errors
7686std_error = std(errors)
7787
7888p, t = get_trainset(bounds, tspan, 100, 0.01)
7989ground_solution_ = ground_solution_f(p, t)
80- predict = sol.interp(( p, t) )
90+ predict = sol.interp(p, t)
8191
8292errors = ground_solution_ - predict
8393mean_error = mean(errors)
0 commit comments