Skip to content

Conversation

@ChrisRackauckas-Claude
Copy link
Contributor

Summary

This PR investigates and attempts to address the precision loss in automatic differentiation through NeuralPDE's loss functions, as reported in #931. The issue manifests as JVP errors of ~1e-8 instead of the expected ~1e-16 for Float64 operations.

Problem Analysis

When computing Jacobian-vector products (JVPs) through the residual function using ForwardDiff, significant precision loss occurs:

  • Direct model evaluation JVPs: ~1e-16 error ✓ (expected Float64 precision)
  • Residual function JVPs: ~1e-8 error ✗ (degraded precision)

This suggests that somewhere in the loss function computation path, operations are being performed at lower precision or type conversions are forcing precision loss.

Changes Implemented

1. Modified numeric_derivative() in src/pinn_types.jl

Added logic to dynamically recompute epsilon at the precision of the neural network parameters when there's a type mismatch:

# Recompute epsilon at the precision of the parameters θ to maintain AD precision
eltypeθ = recursive_eltype(θ)
if eltypeθ != eltype(ε)
    # Recompute epsilon with correct type
    epsilon_magnitude = eps(eltypeθ)^(one(eltypeθ) / convert(eltypeθ, 2 + order))
    # Reconstruct ε preserving the sparsity pattern
    ε_new = zeros(eltypeθ, length(ε))
    for i in eachindex(ε)
        if !iszero(ε[i])
            ε_new[i] = epsilon_magnitude
        end
    end
    ε = ε_new
    _epsilon = inv(epsilon_magnitude)
end

2. Fixed hardcoded Float64 in src/discretize.jl

Changed integration arrays in get_numeric_integral() to use recursive_eltype(θ) instead of hardcoded Float64:

# Use the precision of the parameters θ for all arrays
eltypeθ = recursive_eltype(θ)
lb_ = zeros(eltypeθ, size(lb)[1], size(cord)[2])
ub_ = zeros(eltypeθ, size(ub)[1], size(cord)[2])
# ...
integration_arr = Matrix{eltypeθ}(undef, 1, 0)

3. Added comprehensive test case

Created test/ad_precision_tests.jl that reproduces the exact issue from the bug report, testing both:

  • Model evaluation JVPs (passes with ~1e-16 error)
  • Residual function JVPs (currently fails with ~2e-8 error)

Current Status

⚠️ Partial Fix: Despite implementing the type-handling improvements, the precision loss persists at ~2e-8 in testing.

Investigation Findings

After extensive investigation, I believe the residual precision loss may be inherent to differentiating through finite difference approximations:

  1. The finite difference stencils in numeric_derivative() use epsilon values of ~1e-6 for Float64
  2. When ForwardDiff differentiates through these finite difference operations, the truncation errors in the finite differences (~1e-12 for second-order methods) get propagated
  3. This propagation through AD can amplify truncation errors, resulting in ~1e-8 precision in the final derivatives

This is a fundamental issue when combining:

  • Finite differences (which have inherent truncation error)
  • Automatic differentiation (which differentiates through those truncation errors)

Potential Solutions

Short-term (Type handling improvements)

  • Fix hardcoded Float64 types
  • Dynamic epsilon recomputation based on parameter type
  • Investigate if there are other type promotion issues in the call chain

Long-term (Architectural changes)

  1. Replace finite differences with AD for PDE operators: Use automatic differentiation directly for computing PDE operator evaluations instead of finite differences. This would eliminate finite difference truncation error entirely.

  2. Hybrid approach: Use finite differences for forward evaluation but switch to pure AD when computing derivatives through the loss.

  3. Higher-order finite differences: Use higher-order stencils to reduce truncation error (though this only mitigates, not eliminates the issue).

Testing

The test case test/ad_precision_tests.jl can be run with:

julia --project=. test/ad_precision_tests.jl

Current results:

  • ✓ Loss function match: < 1e-14
  • ✗ Residual JVP precision: ~2e-8 (target: < 1e-12)
  • ✓ Model JVP precision: < 1e-14

Discussion Points

  1. Is ~1e-8 precision acceptable for the intended use cases, or is 1e-16 required?
  2. Would switching to AD-based PDE operators be acceptable (performance/complexity trade-off)?
  3. Are there other finite-difference-free approaches to computing PDE residuals?

Related Issues

Fixes #931


🤖 Generated with Claude Code

Co-Authored-By: Claude [email protected]

ChrisRackauckas and others added 2 commits November 7, 2025 13:35
This commit attempts to address precision loss in automatic differentiation
through the loss function. The issue manifests as JVP errors of ~1e-8 instead
of the expected ~1e-16 for Float64 operations.

Changes made:
1. Modified numeric_derivative() to recompute epsilon dynamically based on
   parameter type when types don't match the pre-computed template

2. Fixed hardcoded Float64 in get_numeric_integral() to use recursive_eltype(θ)

3. Added test case demonstrating the precision issue

Current status:
- The fix partially addresses type mismatches but doesn't fully resolve the
  precision loss (~2e-8 error remains)
- The residual precision loss may be inherent to differentiating through
  finite difference approximations of PDE operators
- Further investigation needed to determine if additional architectural
  changes are required (e.g., AD-based PDE operators instead of finite differences)

Related to SciML#931

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Potential precision loss in AutoDiff through the loss function

2 participants