Skip to content

Commit 14dab14

Browse files
Final GSOC 2025 commit.
1 parent 24ee4d8 commit 14dab14

File tree

2 files changed

+14
-10
lines changed

2 files changed

+14
-10
lines changed

src/NN_SDE_solve.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,8 @@ function generate_DataMoments_loss(
363363

364364
# moment matching (MSE across time for 1st, 2nd moments) - assumes diffusion is a Gaussian at each timepoint
365365
# uses sample variance
366-
return (θ, _) -> begin
366+
return (θ,
367+
_) -> begin
367368
sum(abs2,
368369
mean(process, dims = 2) .-
369370
mean.(Base.Fix2(phi, θ).(sdephi_inputs))) / length(ts) +
@@ -406,7 +407,8 @@ function generate_EM_L2loss(dataset::Vector{<:Vector}, f, g)
406407
X_increments = reduce(hcat, get_increments.(dataset[1]))
407408
n, n_samples = size(X_increments)
408409

409-
loss_fn = (θ, _) -> begin
410+
loss_fn = (θ,
411+
_) -> begin
410412
gx = reduce(hcat,
411413
[[g(process[i, j], θ.p, dataset[2][i])^2 * Δt[i] for i in 1:n]
412414
for j in 1:n_samples])
@@ -658,7 +660,8 @@ function SciMLBase.__solve(
658660
# For weak training: higher sub_batch corresponds with a narrower confidence band/ increased certainty in the Weak solution.
659661
# For strong training: it means more strong paths to train over.
660662
# weak loss-> weak training is default solve mode.
661-
(; param_estim, sub_batch, strong_loss, moment_loss, chain, opt, autodiff, init_params, batch,
663+
(; param_estim, sub_batch, strong_loss, moment_loss,
664+
chain, opt, autodiff, init_params, batch,
662665
additional_loss, dataset, numensemble, data_sub_batch) = alg
663666
n_z = chain[1].in_dims - 1
664667
sde_phi, init_params = generate_phi(chain, t0, u0, init_params)
@@ -709,7 +712,8 @@ function SciMLBase.__solve(
709712
if moment_loss
710713
# min batch for L2 mean is sub samples of the dataset
711714
data_sub_batch = max(data_sub_batch, length(dataset[1]))
712-
DataMoments_loss, dataset_training_sets = generate_DataMoments_loss(
715+
DataMoments_loss,
716+
dataset_training_sets = generate_DataMoments_loss(
713717
dataset, n_z, sde_phi, f, g,
714718
autodiff, p, param_estim, data_sub_batch, train_type)
715719
end

test/NN_SDE_tests.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,8 @@ end
7676
z3 * sin((3 - 1 / 2) * π * t) / ((3 - 1 / 2) * π))
7777
end
7878
truncated_sol(
79-
u0, t, z1, z2, z3) = u0 *
80-
exp((α - β^2 / 2) * t + β * W_kkl(t, z1, z2, z3))
79+
u0, t, z1, z2, z3) = u0 *
80+
exp((α - β^2 / 2) * t + β * W_kkl(t, z1, z2, z3))
8181

8282
num_samples = 2000
8383
num_time_steps = dt
@@ -446,8 +446,8 @@ end
446446

447447
# estimated sde parameter tests (we trained with 15 observed solution paths).
448448
# absolute value taken for 2nd estimated parameter as loss for variance is independent of this parameter's direction.
449-
@test sol_1.estimated_params[1].ideal_p[1] rtol=2e-1
450-
@test abs(sol_1.estimated_params[2]).≈ideal_p[2] rtol=8e-2
451-
@test sol_2.estimated_params[1].ideal_p[1] rtol=2e-1
452-
@test abs(sol_2.estimated_params[2]).≈ideal_p[2] rtol=8e-2
449+
@test sol_1.estimated_params[1] .≈ ideal_p[1] rtol=2e-1
450+
@test abs(sol_1.estimated_params[2]) .≈ ideal_p[2] rtol=8e-2
451+
@test sol_2.estimated_params[1] .≈ ideal_p[1] rtol=2e-1
452+
@test abs(sol_2.estimated_params[2]) .≈ ideal_p[2] rtol=8e-2
453453
end

0 commit comments

Comments
 (0)