Skip to content

Commit 117ad23

Browse files
committed
move restore into Checkpointer
1 parent 6744bfc commit 117ad23

File tree

9 files changed

+345
-162
lines changed

9 files changed

+345
-162
lines changed

docs/src/checkpointer.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ Checkpoints are a mix of HDF5 and JLD2 files and are typically saved in a
1212
!!! warning "Known limitations"
1313

1414
- The number of MPI processes has to remain the same across checkpoints
15-
- Restart files are generally not portable across machines, julia versions,
15+
- Restart files are generally not portable across machines, julia versions,
1616
and package versions
1717
- Adding/changing new component models will probably require adding/changing code
1818

@@ -56,7 +56,7 @@ If the model does not support directly reading a checkpoint, the `Checkpointer`
5656
module provides a straightforward way to add this feature.
5757
[`Checkpointer.restart!`](@ref) takes a coupled simulation, a `restart_dir`, and
5858
a `restart_t` and overwrites the content of the coupled simulation with what is
59-
in the checkpoint.
59+
in the checkpoint.
6060

6161
## Developer notes
6262

@@ -135,4 +135,5 @@ Types to watch for:
135135
ClimaCoupler.Checkpointer.restart!
136136
ClimaCoupler.Checkpointer.checkpoint_sims
137137
ClimaCoupler.Checkpointer.t_start_from_checkpoint
138+
ClimaCoupler.Checkpointer.restore!
138139
```

experiments/ClimaEarth/components/atmosphere/climaatmos.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ if pkgversion(CA) < v"0.28.6"
2020
CC.Adapt.@adapt_structure CA.RRTMGPInterface.RRTMGPModel
2121
end
2222

23-
include("../shared/restore.jl")
2423

2524
###
2625
### Functions required by ClimaCoupler.jl for an AtmosModelSimulation
@@ -140,7 +139,7 @@ end
140139

141140
function Checkpointer.restore_cache!(sim::ClimaAtmosSimulation, new_cache)
142141
comms_ctx = ClimaComms.context(sim.integrator.u.c)
143-
restore!(
142+
Checkpointer.restore!(
144143
Checkpointer.get_model_cache(sim),
145144
new_cache,
146145
comms_ctx;
@@ -409,14 +408,14 @@ Interfacer.get_field(sim::ClimaAtmosSimulation, ::Val{:height_sfc}) =
409408
function Interfacer.get_field(sim::ClimaAtmosSimulation, ::Val{:u_int})
410409
# NOTE: This calculation is copied from ClimaAtmos (and is allocating! Fix me if you can!)
411410
int_local_geometry_values =
412-
Fields.level(Fields.local_geometry_field(sim.integrator.u.c), 1)
411+
CC.Fields.level(CC.Fields.local_geometry_field(sim.integrator.u.c), 1)
413412
int_u_values = CC.Spaces.level(sim.integrator.p.precomputed.ᶜu, 1)
414413
return CA.projected_vector_data.(CA.CT1, int_u_values, int_local_geometry_values)
415414
end
416415
function Interfacer.get_field(sim::ClimaAtmosSimulation, ::Val{:v_int})
417416
# NOTE: This calculation is copied from ClimaAtmos (and is allocating! Fix me if you can!)
418417
int_local_geometry_values =
419-
Fields.level(Fields.local_geometry_field(sim.integrator.u.c), 1)
418+
CC.Fields.level(CC.Fields.local_geometry_field(sim.integrator.u.c), 1)
420419
int_u_values = CC.Spaces.level(sim.integrator.p.precomputed.ᶜu, 1)
421420
return CA.projected_vector_data.(CA.CT2, int_u_values, int_local_geometry_values)
422421
end

experiments/ClimaEarth/components/land/climaland_bucket.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ import ClimaCoupler: Checkpointer, FluxCalculator, Interfacer, FieldExchanger
1313
using NCDatasets
1414
include("climaland_helpers.jl")
1515

16-
include("../shared/restore.jl")
1716

1817
###
1918
### Functions required by ClimaCoupler.jl for a SurfaceModelSimulation
@@ -411,7 +410,7 @@ end
411410
function Checkpointer.restore_cache!(sim::BucketSimulation, new_cache)
412411
old_cache = Checkpointer.get_model_cache(sim)
413412
comms_ctx = ClimaComms.context(sim.model)
414-
restore!(
413+
Checkpointer.restore!(
415414
old_cache,
416415
new_cache,
417416
comms_ctx,

experiments/ClimaEarth/components/land/climaland_helpers.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ function make_land_domain(
7575

7676
vertmesh = CC.Meshes.IntervalMesh(
7777
vertdomain,
78-
ClimaCore.Meshes.GeneralizedExponentialStretching{FT}(dz_tuple[1], dz_tuple[2]);
78+
CC.Meshes.GeneralizedExponentialStretching{FT}(dz_tuple[1], dz_tuple[2]);
7979
nelems = nelements_vert,
8080
reverse_mode = true,
8181
)

experiments/ClimaEarth/components/land/climaland_integrated.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -533,7 +533,7 @@ end
533533
function Checkpointer.restore_cache!(sim::ClimaLandSimulation, new_cache)
534534
old_cache = Checkpointer.get_model_cache(sim)
535535
comms_ctx = ClimaComms.context(sim.model.soil)
536-
restore!(
536+
Checkpointer.restore!(
537537
old_cache,
538538
new_cache,
539539
comms_ctx,

experiments/ClimaEarth/components/shared/restore.jl

Lines changed: 0 additions & 120 deletions
This file was deleted.

experiments/ClimaEarth/test/compare.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,7 @@
33

44
import ClimaComms
55
import ClimaAtmos as CA
6-
import ClimaCore
7-
import ClimaCore: DataLayouts, Fields, Geometry
8-
import ClimaCore.Fields: Field, FieldVector, field_values
9-
import ClimaCore.DataLayouts: AbstractData
10-
import ClimaCore.Geometry: AxisTensor
11-
import ClimaCore.Spaces: AbstractSpace
6+
import ClimaCore as CC
127
import NCDatasets
138

149
"""
@@ -55,8 +50,8 @@ function compare(
5550
name = "",
5651
ignore = Set([:rc]),
5752
) where {
58-
T1 <: Union{FieldVector, AbstractSpace, NamedTuple, CA.AtmosCache},
59-
T2 <: Union{FieldVector, AbstractSpace, NamedTuple, CA.AtmosCache},
53+
T1 <: Union{CC.Fields.FieldVector, CC.Spaces.AbstractSpace, NamedTuple, CA.AtmosCache},
54+
T2 <: Union{CC.Fields.FieldVector, CC.Spaces.AbstractSpace, NamedTuple, CA.AtmosCache},
6055
}
6156
pass = true
6257
return _compare(pass, v1, v2; name, ignore)
@@ -100,11 +95,16 @@ function _compare(pass, v1::T, v2::T; name, ignore) where {T <: NCDatasets.NCDat
10095
return pass
10196
end
10297

103-
function _compare(v1::T, v2::T; name, ignore) where {T <: Field{<:AbstractData{<:Real}}}
98+
function _compare(
99+
v1::T,
100+
v2::T;
101+
name,
102+
ignore,
103+
) where {T <: CC.Fields.Field{<:CC.DataLayouts.AbstractData{<:Real}}}
104104
return _compare(parent(v1), parent(v2); name, ignore)
105105
end
106106

107-
function _compare(pass, v1::T, v2::T; name, ignore) where {T <: AbstractData}
107+
function _compare(pass, v1::T, v2::T; name, ignore) where {T <: CC.DataLayouts.AbstractData}
108108
return pass && _compare(parent(v1), parent(v2); name, ignore)
109109
end
110110

0 commit comments

Comments
 (0)