@@ -363,7 +363,8 @@ function generate_DataMoments_loss(
 
     # moment matching (MSE across time for 1st, 2nd moments) - assumes diffusion is a Gaussian at each timepoint
     # uses sample variance
-    return (θ, _) -> begin
+    return (θ,
+            _) -> begin
         sum(abs2,
             mean(process, dims = 2) .-
             mean.(Base.Fix2(phi, θ).(sdephi_inputs))) / length(ts) +
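
For reference, the moment-matching objective this hunk touches amounts to a per-timepoint MSE between the data's sample moments and the network's predicted moments (the diff context only shows the first-moment term before the trailing `+`). A minimal standalone sketch, assuming a (timepoints × samples) data matrix and hypothetical `pred_means`/`pred_vars` vectors standing in for the network's per-timepoint Gaussian moments; this is an illustration, not the package's `generate_DataMoments_loss`:

    using Statistics

    # data :: (timepoints × samples) matrix of observed paths
    # pred_means, pred_vars :: predicted per-timepoint mean and variance
    # (hypothetical names, for illustration only)
    function moment_mse(data, pred_means, pred_vars)
        nt = size(data, 1)
        # first-moment term: MSE between sample means and predicted means
        first_mom = sum(abs2, vec(mean(data, dims = 2)) .- pred_means) / nt
        # second-moment term: MSE between sample variances and predicted
        # variances ("uses sample variance", per the comment above)
        second_mom = sum(abs2, vec(var(data, dims = 2)) .- pred_vars) / nt
        return first_mom + second_mom
    end
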
@@ -406,7 +407,8 @@ function generate_EM_L2loss(dataset::Vector{<:Vector}, f, g)
     X_increments = reduce(hcat, get_increments.(dataset[1]))
     n, n_samples = size(X_increments)
 
-    loss_fn = (θ, _) -> begin
+    loss_fn = (θ,
+        _) -> begin
         gx = reduce(hcat,
             [[g(process[i, j], θ.p, dataset[2][i])^2 * Δt[i] for i in 1:n]
              for j in 1:n_samples])
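
The `gx` term above computes the Euler–Maruyama increment variance g²Δt per increment and sample. A hedged single-path sketch of this style of loss, assuming each observed increment is scored against the drift mean f·Δt with variance g²·Δt; the exact weighting inside `generate_EM_L2loss` may differ:

    # Hypothetical single-path sketch (not the package's generate_EM_L2loss):
    # under Euler–Maruyama, ΔX_i ≈ f(X_i, p, t_i)Δt_i + g(X_i, p, t_i)ΔW_i,
    # so each increment has mean f*Δt and variance g^2*Δt.
    function em_l2_path(X, ts, f, g, p)
        loss = zero(eltype(X))
        for i in 1:(length(X) - 1)
            Δt = ts[i + 1] - ts[i]
            ΔX = X[i + 1] - X[i]
            # drift residual scaled by the increment variance g^2 Δt
            loss += (ΔX - f(X[i], p, ts[i]) * Δt)^2 / (g(X[i], p, ts[i])^2 * Δt)
        end
        return loss
    end
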
@@ -658,7 +660,8 @@ function SciMLBase.__solve(
     # For weak training: higher sub_batch corresponds with a narrower confidence band / increased certainty in the weak solution.
     # For strong training: it means more strong paths to train over.
     # weak loss -> weak training is the default solve mode.
-    (; param_estim, sub_batch, strong_loss, moment_loss, chain, opt, autodiff, init_params, batch,
+    (; param_estim, sub_batch, strong_loss, moment_loss,
+        chain, opt, autodiff, init_params, batch,
         additional_loss, dataset, numensemble, data_sub_batch) = alg
     n_z = chain[1].in_dims - 1
     sde_phi, init_params = generate_phi(chain, t0, u0, init_params)
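
The `sub_batch` comments above have a simple statistical reading: the weak solution is a Monte Carlo average over latent draws, and the standard error of an average over n draws shrinks like 1/√n, hence the narrower confidence band. A toy illustration (not package code):

    using Statistics, Random

    rng = Xoshiro(1)
    for sub_batch in (10, 100, 1000)
        draws = randn(rng, sub_batch)      # stand-in for per-latent-draw predictions
        se = std(draws) / sqrt(sub_batch)  # Monte Carlo standard error of the mean
        println("sub_batch = $sub_batch  =>  standard error ≈ ", round(se, digits = 4))
    end
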
@@ -709,7 +712,8 @@ function SciMLBase.__solve(
     if moment_loss
         # min batch for the L2 mean is the number of dataset sub-samples
         data_sub_batch = max(data_sub_batch, length(dataset[1]))
-        DataMoments_loss, dataset_training_sets = generate_DataMoments_loss(
+        DataMoments_loss,
+        dataset_training_sets = generate_DataMoments_loss(
             dataset, n_z, sde_phi, f, g,
             autodiff, p, param_estim, data_sub_batch, train_type)
     end