Skip to content

Commit 114b808

Browse files
committed
move restore into Checkpointer
1 parent 5fb306b commit 114b808

File tree

6 files changed

+180
-128
lines changed

6 files changed

+180
-128
lines changed

docs/src/checkpointer.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ Checkpoints are a mix of HDF5 and JLD2 files and are typically saved in a
1212
!!! warning "Known limitations"
1313

1414
- The number of MPI processes has to remain the same across checkpoints
15-
- Restart files are generally not portable across machines, julia versions,
15+
- Restart files are generally not portable across machines, julia versions,
1616
and package versions
1717
- Adding/changing new component models will probably require adding/changing code
1818

@@ -56,7 +56,7 @@ If the model does not support directly reading a checkpoint, the `Checkpointer`
5656
module provides a straightforward way to add this feature.
5757
[`Checkpointer.restart!`](@ref) takes a coupled simulation, a `restart_dir`, and
5858
a `restart_t` and overwrites the content of the coupled simulation with what is
59-
in the checkpoint.
59+
in the checkpoint.
6060

6161
## Developer notes
6262

@@ -135,4 +135,5 @@ Types to watch for:
135135
ClimaCoupler.Checkpointer.restart!
136136
ClimaCoupler.Checkpointer.checkpoint_sims
137137
ClimaCoupler.Checkpointer.t_start_from_checkpoint
138+
ClimaCoupler.Checkpointer.restore!
138139
```

experiments/ClimaEarth/components/atmosphere/climaatmos.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ if pkgversion(CA) < v"0.28.6"
2020
CC.Adapt.@adapt_structure CA.RRTMGPInterface.RRTMGPModel
2121
end
2222

23-
include("../shared/restore.jl")
2423

2524
###
2625
### Functions required by ClimaCoupler.jl for an AtmosModelSimulation
@@ -140,7 +139,7 @@ end
140139

141140
function Checkpointer.restore_cache!(sim::ClimaAtmosSimulation, new_cache)
142141
comms_ctx = ClimaComms.context(sim.integrator.u.c)
143-
restore!(
142+
Checkpointer.restore!(
144143
Checkpointer.get_model_cache(sim),
145144
new_cache,
146145
comms_ctx;

experiments/ClimaEarth/components/land/climaland_bucket.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ import ClimaCoupler: Checkpointer, FluxCalculator, Interfacer, FieldExchanger
1313
using NCDatasets
1414
include("climaland_helpers.jl")
1515

16-
include("../shared/restore.jl")
1716

1817
###
1918
### Functions required by ClimaCoupler.jl for a SurfaceModelSimulation
@@ -411,7 +410,7 @@ end
411410
function Checkpointer.restore_cache!(sim::BucketSimulation, new_cache)
412411
old_cache = Checkpointer.get_model_cache(sim)
413412
comms_ctx = ClimaComms.context(sim.model)
414-
restore!(
413+
Checkpointer.restore!(
415414
old_cache,
416415
new_cache,
417416
comms_ctx,

experiments/ClimaEarth/components/land/climaland_integrated.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -533,7 +533,7 @@ end
533533
function Checkpointer.restore_cache!(sim::ClimaLandSimulation, new_cache)
534534
old_cache = Checkpointer.get_model_cache(sim)
535535
comms_ctx = ClimaComms.context(sim.model.soil)
536-
restore!(
536+
Checkpointer.restore!(
537537
old_cache,
538538
new_cache,
539539
comms_ctx,

experiments/ClimaEarth/components/shared/restore.jl

Lines changed: 0 additions & 120 deletions
This file was deleted.

src/Checkpointer.jl

Lines changed: 174 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,14 @@ import ClimaComms
99
import ClimaCore as CC
1010
import ClimaUtilities.Utils: sort_by_creation_time
1111
import ClimaUtilities.TimeManager: ITime, seconds
12+
import ClimaUtilities.TimeVaryingInputs: AbstractTimeVaryingInput
1213
import ..Interfacer
1314
import Dates
15+
import StaticArrays
1416

1517
import JLD2
1618

17-
export get_model_prog_state, checkpoint_model_state, checkpoint_sims
19+
export get_model_prog_state, checkpoint_model_state, checkpoint_sims, restore!
1820

1921
"""
2022
get_model_prog_state(sim::Interfacer.ComponentModelSimulation)
@@ -295,4 +297,175 @@ function remove_checkpoint(prev_checkpoint_file, prev_checkpoint_t, comms_ctx)
295297
return nothing
296298
end
297299

300+
"""
301+
restore!(v1, v2, comms_ctx; name = "", ignore)
302+
303+
Recursively traverse `v1` and `v2`, setting each field of `v1` with the
304+
corresponding field in `v2`. In this, ignore all the properties that have a name
305+
within the `ignore` iterable.
306+
307+
This is intended to be used when restarting a simulation's cache object
308+
from a checkpoint.
309+
310+
`ignore` is useful when there are stateful properties, such as live pointers.
311+
"""
312+
function restore!(v1::T1, v2::T2, comms_ctx; name = "", ignore) where {T1, T2}
313+
# We pick fieldnames(T2) because v2 tends to be simpler (Array as opposed
314+
# to CuArray)
315+
fields = filter(x -> !(x in ignore), fieldnames(T2))
316+
# If there are no fields to restore, we check for consistency
317+
if isempty(fields)
318+
v1 == v2 || error("$v1 != $v2")
319+
else
320+
# Recursive case: restore each field
321+
for p in fields
322+
restore!(
323+
getfield(v1, p),
324+
getfield(v2, p),
325+
comms_ctx;
326+
name = "$(name).$(p)",
327+
ignore,
328+
)
329+
end
330+
end
331+
return nothing
332+
end
333+
334+
"""
335+
restore!(
336+
v1::Union{
337+
AbstractTimeVaryingInput,
338+
ClimaComms.AbstractCommsContext,
339+
ClimaComms.AbstractDevice,
340+
UnionAll,
341+
DataType,
342+
},
343+
v2::Union{
344+
AbstractTimeVaryingInput,
345+
ClimaComms.AbstractCommsContext,
346+
ClimaComms.AbstractDevice,
347+
UnionAll,
348+
DataType,
349+
},
350+
_comms_ctx;
351+
name,
352+
ignore,
353+
)
354+
355+
Ignore certain types that don't need to be restored.
356+
`UnionAll` and `DataType` are infinitely recursive, so we also ignore those.
357+
"""
358+
function restore!(
359+
v1::Union{
360+
AbstractTimeVaryingInput,
361+
ClimaComms.AbstractCommsContext,
362+
ClimaComms.AbstractDevice,
363+
UnionAll,
364+
DataType,
365+
},
366+
v2::Union{
367+
AbstractTimeVaryingInput,
368+
ClimaComms.AbstractCommsContext,
369+
ClimaComms.AbstractDevice,
370+
UnionAll,
371+
DataType,
372+
},
373+
_comms_ctx;
374+
name,
375+
ignore,
376+
)
377+
return nothing
378+
end
379+
380+
"""
381+
restore!(
382+
v1::Union{CC.DataLayouts.AbstractData, AbstractArray},
383+
v2::Union{CC.DataLayouts.AbstractData, AbstractArray},
384+
comms_ctx;
385+
name,
386+
ignore,
387+
)
388+
389+
For array-like objects, we move the original data (v2) to the
390+
device of the new data (v1). Then we copy the original data to
391+
the new object.
392+
"""
393+
function restore!(
394+
v1::Union{CC.DataLayouts.AbstractData, AbstractArray},
395+
v2::Union{CC.DataLayouts.AbstractData, AbstractArray},
396+
comms_ctx;
397+
name,
398+
ignore,
399+
)
400+
ArrayType =
401+
parent(v1) isa Array ? Array : ClimaComms.array_type(ClimaComms.device(comms_ctx))
402+
moved_to_device = ArrayType(parent(v2))
403+
404+
parent(v1) .= moved_to_device
405+
return nothing
406+
end
407+
408+
"""
409+
restore!(
410+
v1::Union{StaticArrays.StaticArray, Number, UnitRange, LinRange, Symbol},
411+
v2::Union{StaticArrays.StaticArray, Number, UnitRange, LinRange, Symbol},
412+
comms_ctx;
413+
name,
414+
ignore,
415+
)
416+
417+
Ensure that immutable objects have been initialized correctly,
418+
as they cannot be restored from a checkpoint.
419+
"""
420+
function restore!(
421+
v1::Union{StaticArrays.StaticArray, Number, UnitRange, LinRange, Symbol},
422+
v2::Union{StaticArrays.StaticArray, Number, UnitRange, LinRange, Symbol},
423+
comms_ctx;
424+
name,
425+
ignore,
426+
)
427+
v1 == v2 || error("$name is a immutable but it inconsistent ($(v1) != $(v2))")
428+
return nothing
429+
end
430+
431+
"""
432+
restore!(v1::Dict, v2::Dict, comms_ctx; name, ignore)
433+
434+
RRTMGP has some internal dictionaries, which we check for consistency.
435+
"""
436+
function restore!(v1::Dict, v2::Dict, comms_ctx; name, ignore)
437+
v1 == v2 || error("$name is inconsistent")
438+
return nothing
439+
end
440+
441+
"""
442+
restore!(
443+
v1::T1,
444+
v2::T2,
445+
comms_ctx;
446+
name,
447+
ignore,
448+
) where {
449+
T1 <: Union{Dates.DateTime, Dates.UTInstant, Dates.Millisecond},
450+
T2 <: Union{Dates.DateTime, Dates.UTInstant, Dates.Millisecond},
451+
}
452+
453+
Special case for time-related types: differing timestamps are allowed during restore and only produce a warning.
454+
"""
455+
function restore!(
456+
v1::T1,
457+
v2::T2,
458+
comms_ctx;
459+
name,
460+
ignore,
461+
) where {
462+
T1 <: Union{Dates.DateTime, Dates.UTInstant, Dates.Millisecond},
463+
T2 <: Union{Dates.DateTime, Dates.UTInstant, Dates.Millisecond},
464+
}
465+
if v1 != v2
466+
@warn "Time value differs in restart" field = name original = v2 new = v1
467+
end
468+
return nothing
469+
end
470+
298471
end # module

0 commit comments

Comments
 (0)