@@ -9,12 +9,14 @@ import ClimaComms
99import ClimaCore as CC
1010import ClimaUtilities. Utils: sort_by_creation_time
1111import ClimaUtilities. TimeManager: ITime, seconds
12+ import ClimaUtilities. TimeVaryingInputs: AbstractTimeVaryingInput
1213import .. Interfacer
1314import Dates
15+ import StaticArrays
1416
1517import JLD2
1618
17- export get_model_prog_state, checkpoint_model_state, checkpoint_sims
19+ export get_model_prog_state, checkpoint_model_state, checkpoint_sims, restore!
1820
1921"""
2022 get_model_prog_state(sim::Interfacer.ComponentModelSimulation)
@@ -295,4 +297,175 @@ function remove_checkpoint(prev_checkpoint_file, prev_checkpoint_t, comms_ctx)
295297 return nothing
296298end
297299
300+ """
301+ restore!(v1, v2, comms_ctx; ignore)
302+
303+ Recursively traverse `v1` and `v2`, setting each field of `v1` with the
304+ corresponding field in `v2`. In this, ignore all the properties that have name
305+ within the `ignore` iterable.
306+
307+ This is intended to be used when restarting a simulation's cache object
308+ from a checkpoint.
309+
310+ `ignore` is useful when there are stateful properties, such as live pointers.
311+ """
312+ function restore! (v1:: T1 , v2:: T2 , comms_ctx; name = " " , ignore) where {T1, T2}
313+ # We pick fieldnames(T2) because v2 tend to be simpler (Array as opposed
314+ # to CuArray)
315+ fields = filter (x -> ! (x in ignore), fieldnames (T2))
316+ # If there are no fields to restore, we check for consistency
317+ if isempty (fields)
318+ v1 == v2 || error (" $v1 != $v2 " )
319+ else
320+ # Recursive case: restore each field
321+ for p in fields
322+ restore! (
323+ getfield (v1, p),
324+ getfield (v2, p),
325+ comms_ctx;
326+ name = " $(name) .$(p) " ,
327+ ignore,
328+ )
329+ end
330+ end
331+ return nothing
332+ end
333+
334+ """
335+ restore!(
336+ v1::Union{
337+ AbstractTimeVaryingInput,
338+ ClimaComms.AbstractCommsContext,
339+ ClimaComms.AbstractDevice,
340+ UnionAll,
341+ DataType,
342+ },
343+ v2::Union{
344+ AbstractTimeVaryingInput,
345+ ClimaComms.AbstractCommsContext,
346+ ClimaComms.AbstractDevice,
347+ UnionAll,
348+ DataType,
349+ },
350+ _comms_ctx;
351+ name,
352+ ignore,
353+ )
354+
355+ Ignore certain types that don't need to be restored.
356+ `UnionAll` and `DataType` are infinitely recursive, so we also ignore those.
357+ """
358+ function restore! (
359+ v1:: Union {
360+ AbstractTimeVaryingInput,
361+ ClimaComms. AbstractCommsContext,
362+ ClimaComms. AbstractDevice,
363+ UnionAll,
364+ DataType,
365+ },
366+ v2:: Union {
367+ AbstractTimeVaryingInput,
368+ ClimaComms. AbstractCommsContext,
369+ ClimaComms. AbstractDevice,
370+ UnionAll,
371+ DataType,
372+ },
373+ _comms_ctx;
374+ name,
375+ ignore,
376+ )
377+ return nothing
378+ end
379+
380+ """
381+ restore!(
382+ v1::Union{CC.DataLayouts.AbstractData, AbstractArray},
383+ v2::Union{CC.DataLayouts.AbstractData, AbstractArray},
384+ comms_ctx;
385+ name,
386+ ignore,
387+ )
388+
389+ For array-like objects, we move the original data (v2) to the
390+ device of the new data (v1). Then we copy the original data to
391+ the new object.
392+ """
393+ function restore! (
394+ v1:: Union{CC.DataLayouts.AbstractData, AbstractArray} ,
395+ v2:: Union{CC.DataLayouts.AbstractData, AbstractArray} ,
396+ comms_ctx;
397+ name,
398+ ignore,
399+ )
400+ ArrayType =
401+ parent (v1) isa Array ? Array : ClimaComms. array_type (ClimaComms. device (comms_ctx))
402+ moved_to_device = ArrayType (parent (v2))
403+
404+ parent (v1) .= moved_to_device
405+ return nothing
406+ end
407+
408+ """
409+ restore!(
410+ v1::Union{StaticArrays.StaticArray, Number, UnitRange, LinRange, Symbol},
411+ v2::Union{StaticArrays.StaticArray, Number, UnitRange, LinRange, Symbol},
412+ comms_ctx;
413+ name,
414+ ignore,
415+ )
416+
417+ Ensure that immutable objects have been initialized correctly,
418+ as they cannot be restored from a checkpoint.
419+ """
420+ function restore! (
421+ v1:: Union{StaticArrays.StaticArray, Number, UnitRange, LinRange, Symbol} ,
422+ v2:: Union{StaticArrays.StaticArray, Number, UnitRange, LinRange, Symbol} ,
423+ comms_ctx;
424+ name,
425+ ignore,
426+ )
427+ v1 == v2 || error (" $name is a immutable but it inconsistent ($(v1) != $(v2) )" )
428+ return nothing
429+ end
430+
431+ """
432+ restore!(v1::Dict, v2::Dict, comms_ctx; name, ignore)
433+
434+ RRTMGP has some internal dictionaries, which we check for consistency.
435+ """
436+ function restore! (v1:: Dict , v2:: Dict , comms_ctx; name, ignore)
437+ v1 == v2 || error (" $name is inconsistent" )
438+ return nothing
439+ end
440+
441+ """
442+ restore!(
443+ v1::T1,
444+ v2::T2,
445+ comms_ctx;
446+ name,
447+ ignore,
448+ ) where {
449+ T1 <: Union{Dates.DateTime, Dates.UTInstant, Dates.Millisecond},
450+ T2 <: Union{Dates.DateTime, Dates.UTInstant, Dates.Millisecond},
451+ }
452+
453+ Special case to compare time-related types to allow different timestamps during restore.
454+ """
455+ function restore! (
456+ v1:: T1 ,
457+ v2:: T2 ,
458+ comms_ctx;
459+ name,
460+ ignore,
461+ ) where {
462+ T1 <: Union{Dates.DateTime, Dates.UTInstant, Dates.Millisecond} ,
463+ T2 <: Union{Dates.DateTime, Dates.UTInstant, Dates.Millisecond} ,
464+ }
465+ if v1 != v2
466+ @warn " Time value differs in restart" field = name original = v2 new = v1
467+ end
468+ return nothing
469+ end
470+
298471end # module
0 commit comments