diff --git a/docs/src/checkpointer.md b/docs/src/checkpointer.md index 35bfcdc49f..c92208157d 100644 --- a/docs/src/checkpointer.md +++ b/docs/src/checkpointer.md @@ -12,7 +12,7 @@ Checkpoints are a mix of HDF5 and JLD2 files and are typically saved in a !!! warning "Known limitations" - The number of MPI processes has to remain the same across checkpoints - - Restart files are generally not portable across machines, julia versions, + - Restart files are generally not portable across machines, julia versions, and package versions - Adding/changing new component models will probably require adding/changing code @@ -56,7 +56,7 @@ If the model does not support directly reading a checkpoint, the `Checkpointer` module provides a straightforward way to add this feature. [`Checkpointer.restart!`](@ref) takes a coupled simulation, a `restart_dir`, and a `restart_t` and overwrites the content of the coupled simulation with what is -in the checkpoint. +in the checkpoint. ## Developer notes @@ -135,4 +135,5 @@ Types to watch for: ClimaCoupler.Checkpointer.restart! ClimaCoupler.Checkpointer.checkpoint_sims ClimaCoupler.Checkpointer.t_start_from_checkpoint + ClimaCoupler.Checkpointer.restore! ``` diff --git a/experiments/ClimaEarth/components/atmosphere/climaatmos.jl b/experiments/ClimaEarth/components/atmosphere/climaatmos.jl index 557c820832..141f7e8fe5 100644 --- a/experiments/ClimaEarth/components/atmosphere/climaatmos.jl +++ b/experiments/ClimaEarth/components/atmosphere/climaatmos.jl @@ -20,7 +20,6 @@ if pkgversion(CA) < v"0.28.6" CC.Adapt.@adapt_structure CA.RRTMGPInterface.RRTMGPModel end -include("../shared/restore.jl") ### ### Functions required by ClimaCoupler.jl for an AtmosModelSimulation @@ -140,7 +139,7 @@ end function Checkpointer.restore_cache!(sim::ClimaAtmosSimulation, new_cache) comms_ctx = ClimaComms.context(sim.integrator.u.c) - restore!( + Checkpointer.restore!( Checkpointer.get_model_cache(sim), new_cache, comms_ctx; @@ -409,14 +408,14 @@ Interfacer.get_field(sim::ClimaAtmosSimulation, ::Val{:height_sfc}) = function Interfacer.get_field(sim::ClimaAtmosSimulation, ::Val{:u_int}) # NOTE: This calculation is copied from ClimaAtmos (and is allocating! Fix me if you can!) int_local_geometry_values = - Fields.level(Fields.local_geometry_field(sim.integrator.u.c), 1) + CC.Fields.level(CC.Fields.local_geometry_field(sim.integrator.u.c), 1) int_u_values = CC.Spaces.level(sim.integrator.p.precomputed.ᶜu, 1) return CA.projected_vector_data.(CA.CT1, int_u_values, int_local_geometry_values) end function Interfacer.get_field(sim::ClimaAtmosSimulation, ::Val{:v_int}) # NOTE: This calculation is copied from ClimaAtmos (and is allocating! Fix me if you can!) int_local_geometry_values = - Fields.level(Fields.local_geometry_field(sim.integrator.u.c), 1) + CC.Fields.level(CC.Fields.local_geometry_field(sim.integrator.u.c), 1) int_u_values = CC.Spaces.level(sim.integrator.p.precomputed.ᶜu, 1) return CA.projected_vector_data.(CA.CT2, int_u_values, int_local_geometry_values) end diff --git a/experiments/ClimaEarth/components/land/climaland_bucket.jl b/experiments/ClimaEarth/components/land/climaland_bucket.jl index 5587d662db..410f324c1a 100644 --- a/experiments/ClimaEarth/components/land/climaland_bucket.jl +++ b/experiments/ClimaEarth/components/land/climaland_bucket.jl @@ -13,7 +13,6 @@ import ClimaCoupler: Checkpointer, FluxCalculator, Interfacer, FieldExchanger using NCDatasets include("climaland_helpers.jl") -include("../shared/restore.jl") ### ### Functions required by ClimaCoupler.jl for a SurfaceModelSimulation @@ -411,7 +410,7 @@ end function Checkpointer.restore_cache!(sim::BucketSimulation, new_cache) old_cache = Checkpointer.get_model_cache(sim) comms_ctx = ClimaComms.context(sim.model) - restore!( + Checkpointer.restore!( old_cache, new_cache, comms_ctx, diff --git a/experiments/ClimaEarth/components/land/climaland_helpers.jl b/experiments/ClimaEarth/components/land/climaland_helpers.jl index 1d530b174c..a8f5371ba6 100644 --- a/experiments/ClimaEarth/components/land/climaland_helpers.jl +++ b/experiments/ClimaEarth/components/land/climaland_helpers.jl @@ -75,7 +75,7 @@ function make_land_domain( vertmesh = CC.Meshes.IntervalMesh( vertdomain, - ClimaCore.Meshes.GeneralizedExponentialStretching{FT}(dz_tuple[1], dz_tuple[2]); + CC.Meshes.GeneralizedExponentialStretching{FT}(dz_tuple[1], dz_tuple[2]); nelems = nelements_vert, reverse_mode = true, ) diff --git a/experiments/ClimaEarth/components/land/climaland_integrated.jl b/experiments/ClimaEarth/components/land/climaland_integrated.jl index 7e60ad0b0e..6d62219fa7 100644 --- a/experiments/ClimaEarth/components/land/climaland_integrated.jl +++ b/experiments/ClimaEarth/components/land/climaland_integrated.jl @@ -533,7 +533,7 @@ end function Checkpointer.restore_cache!(sim::ClimaLandSimulation, new_cache) old_cache = Checkpointer.get_model_cache(sim) comms_ctx = ClimaComms.context(sim.model.soil) - restore!( + Checkpointer.restore!( old_cache, new_cache, comms_ctx, diff --git a/experiments/ClimaEarth/components/ocean/prescr_ocean.jl b/experiments/ClimaEarth/components/ocean/prescr_ocean.jl index d86afa34b4..4a318928d9 100644 --- a/experiments/ClimaEarth/components/ocean/prescr_ocean.jl +++ b/experiments/ClimaEarth/components/ocean/prescr_ocean.jl @@ -180,7 +180,7 @@ end function Checkpointer.restore_cache!(sim::PrescribedOceanSimulation, new_cache) old_cache = Checkpointer.get_model_cache(sim) for p in propertynames(old_cache) - if getproperty(old_cache, p) isa Field + if getproperty(old_cache, p) isa CC.Fields.Field ArrayType = ClimaComms.array_type(getproperty(old_cache, p)) parent(getproperty(old_cache, p)) .= ArrayType(parent(getproperty(new_cache, p))) diff --git a/experiments/ClimaEarth/components/ocean/prescr_seaice.jl b/experiments/ClimaEarth/components/ocean/prescr_seaice.jl index fecd8a224b..416f538312 100644 --- a/experiments/ClimaEarth/components/ocean/prescr_seaice.jl +++ b/experiments/ClimaEarth/components/ocean/prescr_seaice.jl @@ -359,7 +359,7 @@ end function Checkpointer.restore_cache!(sim::PrescribedIceSimulation, new_cache) old_cache = Checkpointer.get_model_cache(sim) for p in propertynames(old_cache) - if getproperty(old_cache, p) isa Field + if getproperty(old_cache, p) isa CC.Fields.Field ArrayType = ClimaComms.array_type(getproperty(old_cache, p)) parent(getproperty(old_cache, p)) .= ArrayType(parent(getproperty(new_cache, p))) diff --git a/experiments/ClimaEarth/components/shared/restore.jl b/experiments/ClimaEarth/components/shared/restore.jl deleted file mode 100644 index 551a3722dd..0000000000 --- a/experiments/ClimaEarth/components/shared/restore.jl +++ /dev/null @@ -1,120 +0,0 @@ -# Define shared methods to allow reading back a saved cache - -import ClimaComms -import ClimaCore -import ClimaCore: DataLayouts, Fields, Geometry -import ClimaCore.Fields: Field, FieldVector, field_values -import ClimaCore.DataLayouts: AbstractData -import ClimaCore.Geometry: AxisTensor -import ClimaCore.Spaces: AbstractSpace -import ClimaUtilities.TimeVaryingInputs: AbstractTimeVaryingInput -import StaticArrays -import NCDatasets -import Dates - -""" - restore!(v1, v2, comms_ctx; ignore) - -Recursively traverse `v1` and `v2`, setting each field of `v1` with the -corresponding field in `v2`. In this, ignore all the properties that have name -within the `ignore` iterable. - -`ignore` is useful when there are stateful properties, such as live pointers. -""" -function restore!(v1::T1, v2::T2, comms_ctx; name = "", ignore) where {T1, T2} - # We pick fieldnames(T2) because v2 tend to be simpler (Array as opposed - # to CuArray) - fields = filter(x -> !(x in ignore), fieldnames(T2)) - if isempty(fields) - v1 == v2 || error("$v1 != $v2") - else - # Recursive case - for p in fields - restore!( - getfield(v1, p), - getfield(v2, p), - comms_ctx; - name = "$(name).$(p)", - ignore, - ) - end - end - return nothing -end - -# Ignoring certain types that don't need to be restored -# UnionAll and DataType are infinitely recursive, so we also ignore those -function restore!( - v1::Union{ - AbstractTimeVaryingInput, - ClimaComms.AbstractCommsContext, - ClimaComms.AbstractDevice, - UnionAll, - DataType, - }, - v2::Union{ - AbstractTimeVaryingInput, - ClimaComms.AbstractCommsContext, - ClimaComms.AbstractDevice, - UnionAll, - DataType, - }, - _comms_ctx; - name, - ignore, -) - return nothing -end - -function restore!( - v1::Union{AbstractData, AbstractArray}, - v2::Union{AbstractData, AbstractArray}, - comms_ctx; - name, - ignore, -) - ArrayType = - parent(v1) isa Array ? Array : ClimaComms.array_type(ClimaComms.device(comms_ctx)) - moved_to_device = ArrayType(parent(v2)) - - parent(v1) .= moved_to_device - return nothing -end - -function restore!( - v1::Union{StaticArrays.StaticArray, Number, UnitRange, LinRange, Symbol}, - v2::Union{StaticArrays.StaticArray, Number, UnitRange, LinRange, Symbol}, - comms_ctx; - name, - ignore, -) - v1 == v2 || error("$name is a immutable but it inconsistent ($(v1) != $(v2))") - return nothing -end - -function restore!(v1::Dict, v2::Dict, comms_ctx; name, ignore) - # RRTGMP has some internal dictionaries - v1 == v2 || error("$name is inconsistent") - return nothing -end - -""" - restore!(v1::T1, v2::T2, comms_ctx; name = "", ignore) where {T1 <: Union{Dates.DateTime, Dates.UTInstant, Dates.Millisecond}, T2 <: Union{Dates.DateTime, Dates.UTInstant, Dates.Millisecond}} - -Special case for time-related types to allow different timestamps during restore. -""" -function restore!( - v1::T1, - v2::T2, - comms_ctx; - name, - ignore, -) where { - T1 <: Union{Dates.DateTime, Dates.UTInstant, Dates.Millisecond}, - T2 <: Union{Dates.DateTime, Dates.UTInstant, Dates.Millisecond}, -} - if v1 != v2 - @warn "Time value differs in restart" field = name original = v2 new = v1 - end - return nothing -end diff --git a/experiments/ClimaEarth/test/compare.jl b/experiments/ClimaEarth/test/compare.jl index 3b63b0d12b..4ef2c50ac1 100644 --- a/experiments/ClimaEarth/test/compare.jl +++ b/experiments/ClimaEarth/test/compare.jl @@ -3,12 +3,7 @@ import ClimaComms import ClimaAtmos as CA -import ClimaCore -import ClimaCore: DataLayouts, Fields, Geometry -import ClimaCore.Fields: Field, FieldVector, field_values -import ClimaCore.DataLayouts: AbstractData -import ClimaCore.Geometry: AxisTensor -import ClimaCore.Spaces: AbstractSpace +import ClimaCore as CC import NCDatasets """ @@ -55,8 +50,8 @@ function compare( name = "", ignore = Set([:rc]), ) where { - T1 <: Union{FieldVector, AbstractSpace, NamedTuple, CA.AtmosCache}, - T2 <: Union{FieldVector, AbstractSpace, NamedTuple, CA.AtmosCache}, + T1 <: Union{CC.Fields.FieldVector, CC.Spaces.AbstractSpace, NamedTuple, CA.AtmosCache}, + T2 <: Union{CC.Fields.FieldVector, CC.Spaces.AbstractSpace, NamedTuple, CA.AtmosCache}, } pass = true return _compare(pass, v1, v2; name, ignore) @@ -100,11 +95,16 @@ function _compare(pass, v1::T, v2::T; name, ignore) where {T <: NCDatasets.NCDat return pass end -function _compare(v1::T, v2::T; name, ignore) where {T <: Field{<:AbstractData{<:Real}}} +function _compare( + v1::T, + v2::T; + name, + ignore, +) where {T <: CC.Fields.Field{<:CC.DataLayouts.AbstractData{<:Real}}} return _compare(parent(v1), parent(v2); name, ignore) end -function _compare(pass, v1::T, v2::T; name, ignore) where {T <: AbstractData} +function _compare(pass, v1::T, v2::T; name, ignore) where {T <: CC.DataLayouts.AbstractData} return pass && _compare(parent(v1), parent(v2); name, ignore) end diff --git a/src/Checkpointer.jl b/src/Checkpointer.jl index 34a827bed4..b525d08646 100644 --- a/src/Checkpointer.jl +++ b/src/Checkpointer.jl @@ -9,12 +9,14 @@ import ClimaComms import ClimaCore as CC import ClimaUtilities.Utils: sort_by_creation_time import ClimaUtilities.TimeManager: ITime, seconds +import ClimaUtilities.TimeVaryingInputs: AbstractTimeVaryingInput import ..Interfacer import Dates +import StaticArrays import JLD2 -export get_model_prog_state, checkpoint_model_state, checkpoint_sims +export get_model_prog_state, checkpoint_model_state, checkpoint_sims, restore! """ get_model_prog_state(sim::Interfacer.ComponentModelSimulation) @@ -295,4 +297,175 @@ function remove_checkpoint(prev_checkpoint_file, prev_checkpoint_t, comms_ctx) return nothing end +""" + restore!(v1, v2, comms_ctx; name = "", ignore = Set()) + +Recursively traverse `v1` and `v2`, setting each field of `v1` with the +corresponding field in `v2`. In this, ignore all the properties that have name +within the `ignore` iterable. + +This is intended to be used when restarting a simulation's cache object +from a checkpoint. + +`ignore` is useful when there are stateful properties, such as live pointers. +""" +function restore!(v1::T1, v2::T2, comms_ctx; name = "", ignore = Set()) where {T1, T2} + # We pick fieldnames(T2) because v2 tend to be simpler (Array as opposed + # to CuArray) + fields = filter(x -> !(x in ignore), fieldnames(T2)) + # If there are no fields to restore, we check for consistency + if isempty(fields) + v1 == v2 || error("$v1 != $v2") + else + # Recursive case: restore each field + for p in fields + restore!( + getfield(v1, p), + getfield(v2, p), + comms_ctx; + name = "$(name).$(p)", + ignore, + ) + end + end + return nothing +end + +""" + restore!( + v1::Union{ + AbstractTimeVaryingInput, + ClimaComms.AbstractCommsContext, + ClimaComms.AbstractDevice, + UnionAll, + DataType, + }, + v2::Union{ + AbstractTimeVaryingInput, + ClimaComms.AbstractCommsContext, + ClimaComms.AbstractDevice, + UnionAll, + DataType, + }, + _comms_ctx; + name = "", + ignore = Set(), + ) + +Ignore certain types that don't need to be restored. +`UnionAll` and `DataType` are infinitely recursive, so we also ignore those. +""" +function restore!( + v1::Union{ + AbstractTimeVaryingInput, + ClimaComms.AbstractCommsContext, + ClimaComms.AbstractDevice, + UnionAll, + DataType, + }, + v2::Union{ + AbstractTimeVaryingInput, + ClimaComms.AbstractCommsContext, + ClimaComms.AbstractDevice, + UnionAll, + DataType, + }, + _comms_ctx; + name = "", + ignore = Set(), +) + return nothing +end + +""" + restore!( + v1::Union{CC.DataLayouts.AbstractData, AbstractArray}, + v2::Union{CC.DataLayouts.AbstractData, AbstractArray}, + comms_ctx; + name = "", + ignore = Set(), + ) + +For array-like objects, we move the original data (v2) to the +device of the new data (v1). Then we copy the original data to +the new object. +""" +function restore!( + v1::Union{CC.DataLayouts.AbstractData, AbstractArray}, + v2::Union{CC.DataLayouts.AbstractData, AbstractArray}, + comms_ctx; + name = "", + ignore = Set(), +) + ArrayType = + parent(v1) isa Array ? Array : ClimaComms.array_type(ClimaComms.device(comms_ctx)) + moved_to_device = ArrayType(parent(v2)) + + parent(v1) .= moved_to_device + return nothing +end + +""" + restore!( + v1::Union{StaticArrays.StaticArray, Number, UnitRange, LinRange, Symbol}, + v2::Union{StaticArrays.StaticArray, Number, UnitRange, LinRange, Symbol}, + comms_ctx; + name = "", + ignore = Set(), + ) + +Ensure that immutable objects have been initialized correctly, +as they cannot be restored from a checkpoint. +""" +function restore!( + v1::Union{StaticArrays.StaticArray, Number, UnitRange, LinRange, Symbol}, + v2::Union{StaticArrays.StaticArray, Number, UnitRange, LinRange, Symbol}, + comms_ctx; + name = "", + ignore = Set(), +) + v1 == v2 || error("$name is immutable but it inconsistent ($(v1) != $(v2))") + return nothing +end + +""" + restore!(v1::Dict, v2::Dict, comms_ctx; name = "", ignore = Set()) + +RRTMGP has some internal dictionaries, which we check for consistency. +""" +function restore!(v1::Dict, v2::Dict, comms_ctx; name = "", ignore = Set()) + v1 == v2 || error("$name is inconsistent") + return nothing +end + +""" + restore!( + v1::T1, + v2::T2, + comms_ctx; + name = "", + ignore = Set(), + ) where { + T1 <: Union{Dates.DateTime, Dates.UTInstant, Dates.Millisecond}, + T2 <: Union{Dates.DateTime, Dates.UTInstant, Dates.Millisecond}, + } + +Special case to compare time-related types to allow different timestamps during restore. +""" +function restore!( + v1::T1, + v2::T2, + comms_ctx; + name = "", + ignore = Set(), +) where { + T1 <: Union{Dates.DateTime, Dates.UTInstant, Dates.Millisecond}, + T2 <: Union{Dates.DateTime, Dates.UTInstant, Dates.Millisecond}, +} + if v1 != v2 + @warn "Time value differs in restart" field = name original = v2 new = v1 + end + return nothing +end + end # module diff --git a/test/checkpointer_tests.jl b/test/checkpointer_tests.jl index ff6df857a7..b1e6068f80 100644 --- a/test/checkpointer_tests.jl +++ b/test/checkpointer_tests.jl @@ -1,10 +1,18 @@ -import Test: @test, @testset +using Test import ClimaComms ClimaComms.@import_required_backends import ClimaCore as CC import ClimaCoupler: Checkpointer, Interfacer +import StaticArrays +import Dates -FT = Float64 +const FT = Float64 +const space_checkpointer = CC.CommonSpaces.CubedSphereSpace( + FT; + radius = FT(6.371e6), # in meters + n_quad_points = 4, + h_elem = 4, +) struct DummySimulation{S} <: Interfacer.AtmosModelSimulation state::S @@ -12,13 +20,7 @@ end Checkpointer.get_model_prog_state(sim::DummySimulation) = sim.state @testset "get_model_prog_state" begin - boundary_space = CC.CommonSpaces.CubedSphereSpace( - FT; - radius = FT(6.371e6), # in meters - n_quad_points = 4, - h_elem = 4, - ) - sim = DummySimulation((; T = ones(boundary_space))) + sim = DummySimulation((; T = ones(space_checkpointer))) @test Checkpointer.get_model_prog_state(sim) == sim.state sim2 = Interfacer.SurfaceStub([]) @@ -27,30 +29,158 @@ end @testset "checkpoint_model_state, restart_model_state!" begin comms_ctx = ClimaComms.context(ClimaComms.CPUSingleThreaded()) - boundary_space = CC.CommonSpaces.CubedSphereSpace( - FT; - comms_ctx, - radius = FT(6.371e6), # in meters - n_quad_points = 4, - h_elem = 4, - ) t = 1 prev_checkpoint_t = -1 # old sim run - sim = DummySimulation(CC.Fields.FieldVector(T = ones(boundary_space))) + sim = DummySimulation(CC.Fields.FieldVector(T = ones(space_checkpointer))) + + dir = mktempdir(; prefix = "test_checkpoint_") Checkpointer.checkpoint_model_state( sim, comms_ctx, t, prev_checkpoint_t, - output_dir = "test_checkpoint", + output_dir = dir, ) # new sim run - sim_new = DummySimulation(CC.Fields.FieldVector(T = zeros(boundary_space))) - Checkpointer.restart_model_state!(sim_new, comms_ctx, t, input_dir = "test_checkpoint") + sim_new = DummySimulation(CC.Fields.FieldVector(T = zeros(space_checkpointer))) + input_file = joinpath(dir, "checkpoint_$(nameof(sim_new))_$t.hdf5") + Checkpointer.restart_model_state!(sim_new, input_file, comms_ctx) @test sim_new.state.T == sim.state.T # remove checkpoint directory - rm("./test_checkpoint/", force = true, recursive = true) + rm(dir, force = true, recursive = true) +end + +@testset "restore! for different types" begin + comms_ctx = ClimaComms.context() + + # Test restore! for arrays + v1 = [1.0, 2.0, 3.0] + v2 = [4.0, 5.0, 6.0] + Checkpointer.restore!(v1, v2, comms_ctx) + @test v1 == v2 + + # Test restore! for ClimaCore data layouts + v1 = CC.Fields.field_values(ones(space_checkpointer)) + v2 = CC.Fields.field_values(zeros(space_checkpointer)) + Checkpointer.restore!(v1, v2, comms_ctx) + @test v1 == v2 + + # Test restore! for StaticArrays + v1 = StaticArrays.SVector{3, Float64}(1.0, 2.0, 3.0) + v2 = StaticArrays.SVector{3, Float64}(1.0, 2.0, 3.0) + Checkpointer.restore!(v1, v2, comms_ctx) + @test v1 == v2 + + v3 = StaticArrays.SVector{3, Float64}(4.0, 5.0, 6.0) + @test_throws ErrorException Checkpointer.restore!(v1, v3, comms_ctx) + + # Test restore! for Numbers + v1 = 42 + v2 = 42 + Checkpointer.restore!(v1, v2, comms_ctx) + @test v1 == v2 + + v3 = 43 + @test_throws ErrorException Checkpointer.restore!(v1, v3, comms_ctx) + + # Test restore! for UnitRange + v1 = 1:5 + v2 = 1:5 + Checkpointer.restore!(v1, v2, comms_ctx) + @test v1 == v2 + + v3 = 1:6 + @test_throws ErrorException Checkpointer.restore!(v1, v3, comms_ctx) + + # Test restore! for Symbol + v1 = :test + v2 = :test + Checkpointer.restore!(v1, v2, comms_ctx) + @test v1 == v2 + + v3 = :other + @test_throws ErrorException Checkpointer.restore!(v1, v3, comms_ctx) + + # Test restore! for Dict + v1 = Dict(:a => 1, :b => 2) + v2 = Dict(:a => 1, :b => 2) + Checkpointer.restore!(v1, v2, comms_ctx) + @test v1 == v2 + + v3 = Dict(:a => 1, :b => 3) + @test_throws ErrorException Checkpointer.restore!(v1, v3, comms_ctx) + + # Test restore! for Dates types + v1 = Dates.DateTime(2000, 1, 1) + v2 = Dates.DateTime(2000, 1, 1) + Checkpointer.restore!(v1, v2, comms_ctx) + @test v1 == v2 + + v3 = Dates.DateTime(2000, 1, 2) + @test_logs (:warn, "Time value differs in restart") Checkpointer.restore!( + v1, + v3, + comms_ctx, + ) + + # Test restore! for Dates.UTInstant + v1 = Dates.UTInstant(Dates.Minute(0)) + v2 = Dates.UTInstant(Dates.Minute(0)) + Checkpointer.restore!(v1, v3, comms_ctx) + @test v1 == v2 + + v3 = Dates.UTInstant(Dates.Minute(1)) + @test_logs (:warn, "Time value differs in restart") Checkpointer.restore!( + v1, + v3, + comms_ctx, + ) + + # Test restore! for Dates.Millisecond + v1 = Dates.Millisecond(1000) + v2 = Dates.Millisecond(1000) + Checkpointer.restore!(v1, v2, comms_ctx) + @test v1 == v2 + + v3 = Dates.Millisecond(2000) + @test_logs (:warn, "Time value differs in restart") Checkpointer.restore!( + v1, + v3, + comms_ctx, + ) + + # Test restore! does nothing for comms contexts, data types + v1 = ClimaComms.context() + v2 = ClimaComms.context() + @test isnothing(Checkpointer.restore!(v1, v2, comms_ctx)) + + v1 = Float64 + v2 = Float32 + @test isnothing(Checkpointer.restore!(v1, v2, comms_ctx)) + + # Test restore! for structs with fields + struct TestStructError + a::Float64 + end + v1 = TestStructError(1.0) + v2 = TestStructError(5.0) + @test_throws ErrorException Checkpointer.restore!(v1, v2, comms_ctx) + @test v1.a != v2.a + + # Test restore! with ignore parameter + struct TestStructWithIgnore + a::Ref{Float64} + b::Int + c::Vector{Float64} + end + + v1 = TestStructWithIgnore(Ref(1.0), 2, [3.0, 4.0]) + v2 = TestStructWithIgnore(Ref(1.0), 6, [7.0, 8.0]) + Checkpointer.restore!(v1, v2, comms_ctx; ignore = Set([:b])) + @test v1.a[] == v2.a[] # Ref doesn't get restored + @test v1.b != v2.b # Int doesn't get restored + @test v1.c == v2.c end