From 666b50986c5902146076e631b133eb89ca8546d7 Mon Sep 17 00:00:00 2001
From: ChrisRackauckas
Date: Fri, 7 Nov 2025 13:17:16 -0500
Subject: [PATCH 1/2] WIP: Attempt to fix AD precision loss (Issue #931)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This commit attempts to address precision loss in automatic differentiation
through the loss function. The issue manifests as JVP errors of ~1e-8 instead
of the expected ~1e-16 for Float64 operations.

Changes made:
1. Modified numeric_derivative() to recompute epsilon dynamically based on
   the parameter type when it does not match the pre-computed template
2. Fixed hardcoded Float64 in get_numeric_integral() to use recursive_eltype(θ)
3. Added a test case demonstrating the precision issue

Current status:
- The fix partially addresses type mismatches but doesn't fully resolve the
  precision loss (~2e-8 error remains)
- The residual precision loss may be inherent to differentiating through
  finite difference approximations of PDE operators
- Further investigation is needed to determine whether additional architectural
  changes are required (e.g., AD-based PDE operators instead of finite
  differences)

Related to #931
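
For reference, the kind of JVP comparison used to measure the precision loss
can be sketched as follows (a minimal standalone illustration, not the exact
test code; `residual` and `θ` below are stand-ins for the discretized PINN
residual function and its flattened Float64 parameters):

```julia
using ForwardDiff, DifferentiationInterface, LinearAlgebra

# Stand-ins for the PINN residual function and its flattened parameters
residual(θ) = θ .^ 2 .- sin.(θ)
θ = randn(100)
v = randn(length(θ))

# The same Jacobian-vector product computed two ways; for Float64 inputs the
# relative difference should sit near machine precision (~1e-16). On the PINN
# residual it was ~1e-8 before this change.
jvp_explicit = ForwardDiff.jacobian(residual, θ) * v
jvp_pushforward = DifferentiationInterface.pushforward(
    residual, AutoForwardDiff(), θ, (v,))[1]
rel_err = norm(jvp_explicit - jvp_pushforward) / norm(jvp_explicit)
```
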
🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
---
 src/discretize.jl          |   9 ++-
 src/pinn_types.jl          |  17 ++++
 test/ad_precision_tests.jl | 158 +++++++++++++++++++++++++++++++++++++
 3 files changed, 181 insertions(+), 3 deletions(-)
 create mode 100644 test/ad_precision_tests.jl

diff --git a/src/discretize.jl b/src/discretize.jl
index 433cccc15..69ec699b1 100644
--- a/src/discretize.jl
+++ b/src/discretize.jl
@@ -310,8 +310,11 @@ function get_numeric_integral(pinnrep::PINNRepresentation)
             return sol
         end
 
-        lb_ = zeros(size(lb)[1], size(cord)[2])
-        ub_ = zeros(size(ub)[1], size(cord)[2])
+        # Use the precision of the parameters θ for all arrays
+        eltypeθ = recursive_eltype(θ)
+
+        lb_ = zeros(eltypeθ, size(lb)[1], size(cord)[2])
+        ub_ = zeros(eltypeθ, size(ub)[1], size(cord)[2])
         for (i, l) in enumerate(lb)
             if l isa Number
                 @ignore_derivatives lb_[i, :] .= l
@@ -328,7 +331,7 @@ function get_numeric_integral(pinnrep::PINNRepresentation)
                     nothing, u, nothing)
             end
         end
-        integration_arr = Matrix{Float64}(undef, 1, 0)
+        integration_arr = Matrix{eltypeθ}(undef, 1, 0)
         for i in 1:size(cord, 2)
             integration_arr = hcat(integration_arr,
                 integration_(cord[:, i], lb_[:, i], ub_[:, i], θ))
diff --git a/src/pinn_types.jl b/src/pinn_types.jl
index db9e64e5b..64892f065 100644
--- a/src/pinn_types.jl
+++ b/src/pinn_types.jl
@@ -367,6 +367,23 @@ get_u() = (cord, θ, phi) -> phi(cord, θ)
 function numeric_derivative(phi, u, x, εs, order, θ)
     ε = εs[order]
     _epsilon = inv(first(ε[ε .!= zero(ε)]))
+
+    # Recompute epsilon at the precision of the parameters θ to maintain AD precision
+    eltypeθ = recursive_eltype(θ)
+    if eltypeθ != eltype(ε)
+        # Recompute epsilon with correct type
+        epsilon_magnitude = eps(eltypeθ)^(one(eltypeθ) / convert(eltypeθ, 2 + order))
+        # Reconstruct ε preserving the sparsity pattern
+        ε_new = zeros(eltypeθ, length(ε))
+        for i in eachindex(ε)
+            if !iszero(ε[i])
+                ε_new[i] = epsilon_magnitude
+            end
+        end
+        ε = ε_new
+        _epsilon = inv(epsilon_magnitude)
+    end
+
     ε = ε |> safe_get_device(x)
 
     # any(x->x!=εs[1],εs)
diff --git a/test/ad_precision_tests.jl b/test/ad_precision_tests.jl
new file mode 100644
index 000000000..b6c2bd7f6
--- /dev/null
+++ b/test/ad_precision_tests.jl
@@ -0,0 +1,158 @@
+using Test
+using NeuralPDE, Lux, Random, ComponentArrays
+using Optimization
+using OptimizationOptimisers
+using DomainSets: Interval
+using ModelingToolkit: @parameters, @variables, PDESystem, Differential
+using Printf
+
+# Test for issue #931: Precision loss in AutoDiff through the loss function
+@testset "AD Precision Tests (Issue #931)" begin
+    using ForwardDiff, DifferentiationInterface, LinearAlgebra
+
+    @parameters t x y
+    @variables u(..)
+    Dxx = Differential(x)^2
+    Dyy = Differential(y)^2
+    Dt = Differential(t)
+    t_min = 0.0
+    t_max = 2.0
+    x_min = 0.0
+    x_max = 2.0
+    y_min = 0.0
+    y_max = 2.0
+
+    # 2D PDE
+    eq = Dt(u(t, x, y)) ~ Dxx(u(t, x, y)) + Dyy(u(t, x, y))
+
+    analytic_sol_func(t, x, y) = exp(x + y) * cos(x + y + 4t)
+    # Initial and boundary conditions
+    bcs = [u(t_min, x, y) ~ analytic_sol_func(t_min, x, y),
+        u(t, x_min, y) ~ analytic_sol_func(t, x_min, y),
+        u(t, x_max, y) ~ analytic_sol_func(t, x_max, y),
+        u(t, x, y_min) ~ analytic_sol_func(t, x, y_min),
+        u(t, x, y_max) ~ analytic_sol_func(t, x, y_max)]
+
+    # Space and time domains
+    domains = [t ∈ Interval(t_min, t_max),
+        x ∈ Interval(x_min, x_max),
+        y ∈ Interval(y_min, y_max)]
+
+    # Neural network
+    inner = 25
+    chain = Chain(Dense(3, inner, σ), Dense(inner, 1))
+
+    strategy = GridTraining(0.1)
+    ps, st = Lux.setup(Random.default_rng(), chain)
+    ps = ps |> ComponentArray .|> Float64
+    discretization = PhysicsInformedNN(chain, strategy; init_params = ps)
+
+    @named pde_system = PDESystem(eq, bcs, domains, [t, x, y], [u(t, x, y)])
+    prob = discretize(pde_system, discretization)
+    symprob = symbolic_discretize(pde_system, discretization)
+
+    # Get the full residual function
+    function get_residual_vector(pinnrep, loss_function, train_set)
+        eltypeθ = NeuralPDE.recursive_eltype(pinnrep.flat_init_params)
+        train_set = train_set |> NeuralPDE.safe_get_device(pinnrep.init_params) |>
+                    NeuralPDE.EltypeAdaptor{eltypeθ}()
+        return θ -> loss_function(train_set, θ)
+    end
+
+    function get_full_residual(prob, symprob)
+        # Get training sets
+        (; domains, eqs, bcs, dict_indvars, dict_depvars, strategy) = symprob
+        eltypeθ = NeuralPDE.recursive_eltype(symprob.flat_init_params)
+        adaptor = NeuralPDE.EltypeAdaptor{eltypeθ}()
+
+        train_sets = NeuralPDE.generate_training_sets(domains, strategy.dx, eqs, bcs,
+            eltypeθ,
+            dict_indvars, dict_depvars)
+        pde_train_sets, bcs_train_sets = train_sets |> adaptor
+
+        # Get residuals
+        pde_residuals = [get_residual_vector(symprob, _loss, _set)
+                         for (_loss, _set) in zip(
+            symprob.loss_functions.datafree_pde_loss_functions, pde_train_sets)]
+        bc_residuals = [get_residual_vector(symprob, _loss, _set)
+                        for (_loss, _set) in zip(
+            symprob.loss_functions.datafree_bc_loss_functions, bcs_train_sets)]
+
+        # Setup adaloss weights (assuming NonAdaptiveLoss)
+        num_pde_losses = length(pde_residuals)
+        num_bc_losses = length(bc_residuals)
+        adaloss = symprob.adaloss
+        adaloss_T = eltype(adaloss.pde_loss_weights)
+
+        function full_residual(θ)
+            pde_losses = [pde_residual(θ) for pde_residual in pde_residuals]
+            bc_losses = [bc_residual(θ) for bc_residual in bc_residuals]
+
+            weighted_pde_losses = sqrt.(adaloss.pde_loss_weights) .* pde_losses ./
+                                  sqrt.(length.(pde_losses))
+            weighted_bc_losses = sqrt.(adaloss.bc_loss_weights) .* bc_losses ./
+                                 sqrt.(length.(bc_losses))
+
+            full_res = hcat(hcat(weighted_pde_losses...), hcat(weighted_bc_losses...))
+            return full_res
+        end
+
+        return full_residual
+    end
+
+    residual = get_full_residual(prob, symprob)
+    loss = θ -> sum(abs2, residual(θ))
+    loss_neuralpdes = θ -> prob.f(θ, prob.p)
+
+    θ = prob.u0
+
+    # Test 1: Sanity check that our loss matches NeuralPDE's loss
+    rel_err = abs(loss_neuralpdes(θ) - loss(θ)) / abs(loss_neuralpdes(θ))
+    @test rel_err < 1e-14
+    println("Loss function match error: $rel_err")
+
+    # Test 2: Check JVP precision on the residual function
+    v = randn(length(θ))
+    J_fwd = ForwardDiff.jacobian(residual, θ)
+    jvp_explicit = J_fwd * v
+    jvp_pushforward = DifferentiationInterface.pushforward(
+        residual,
+        AutoForwardDiff(),
+        θ,
+        (v,),
+    )[1]
+
+    jvp_error = norm(jvp_explicit - jvp_pushforward[:]) / norm(jvp_explicit)
+    println("AutoForwardDiff error on residual jvp: $jvp_error")
+
+    # This is the key test: the JVP error should be at Float64 precision (< 1e-14)
+    # Previously this would be ~1e-8 due to precision loss
+    @test jvp_error < 1e-12
+
+    # Test 3: Verify model evaluation also maintains precision
+    function get_quadpoints(symprob, strategy)
+        (; domains, eqs, dict_indvars, dict_depvars) = symprob
+        eltypeθ = NeuralPDE.recursive_eltype(symprob.flat_init_params)
+
+        train_sets = hcat(NeuralPDE.generate_training_sets(domains, strategy.dx, eqs, [],
+            eltypeθ,
+            dict_indvars, dict_depvars)[1]...)
+        return train_sets
+    end
+
+    x_points = get_quadpoints(symprob, strategy)
+    fun = ps -> chain(x_points, ps, st)[1]
+    J_fwd_model = ForwardDiff.jacobian(fun, θ)
+    jvp_explicit_model = J_fwd_model * v
+    jvp_pushforward_model = DifferentiationInterface.pushforward(
+        fun,
+        AutoForwardDiff(),
+        θ,
+        (v,),
+    )[1]
+
+    model_jvp_error = norm(jvp_explicit_model - jvp_pushforward_model[:]) /
+                      norm(jvp_explicit_model)
+    println("AutoForwardDiff error on model jvp: $model_jvp_error")
+    @test model_jvp_error < 1e-14
+end

From 32f6468cc0d12ec92bd483cb0952279fd09a8971 Mon Sep 17 00:00:00 2001
From: ChrisRackauckas
Date: Thu, 13 Nov 2025 20:13:17 -0500
Subject: [PATCH 2/2] Apply JuliaFormatter with SciMLStyle

---
 test/ad_precision_tests.jl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/test/ad_precision_tests.jl b/test/ad_precision_tests.jl
index b6c2bd7f6..b7026c80c 100644
--- a/test/ad_precision_tests.jl
+++ b/test/ad_precision_tests.jl
@@ -119,7 +119,7 @@ using Printf
         residual,
         AutoForwardDiff(),
         θ,
-        (v,),
+        (v,)
     )[1]
 
     jvp_error = norm(jvp_explicit - jvp_pushforward[:]) / norm(jvp_explicit)
@@ -148,7 +148,7 @@ using Printf
         fun,
         AutoForwardDiff(),
         θ,
-        (v,),
+        (v,)
     )[1]
 
     model_jvp_error = norm(jvp_explicit_model - jvp_pushforward_model[:]) /