diff --git a/.buildkite/Manifest-v1.11.toml b/.buildkite/Manifest-v1.11.toml index 188416b8739..112b66c5639 100644 --- a/.buildkite/Manifest-v1.11.toml +++ b/.buildkite/Manifest-v1.11.toml @@ -406,9 +406,9 @@ weakdeps = ["CUDA", "MPI"] [[deps.ClimaCore]] deps = ["Adapt", "BandedMatrices", "BlockArrays", "ClimaComms", "CubedSphere", "DataStructures", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "KrylovKit", "LazyBroadcast", "LinearAlgebra", "MultiBroadcastFusion", "NVTX", "PkgVersion", "RecursiveArrayTools", "RootSolvers", "SparseArrays", "StaticArrays", "Statistics", "UnrolledUtilities"] -git-tree-sha1 = "ca717446978d2815b4fa23a62a2131861e44d1e8" +git-tree-sha1 = "a2acae071e36c1c69c94a83d1fb74b25e8b0fde0" uuid = "d414da3d-4745-48bb-8d80-42e94e092884" -version = "0.14.42" +version = "0.14.43" weakdeps = ["CUDA", "Krylov"] [deps.ClimaCore.extensions] @@ -1864,7 +1864,7 @@ version = "2.5.5+0" [[deps.OpenLibm_jll]] deps = ["Artifacts", "Libdl"] uuid = "05823500-19ac-5b8b-9628-191a04bc5112" -version = "0.8.5+0" +version = "0.8.1+4" [[deps.OpenMPI_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "Hwloc_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML", "Zlib_jll"] diff --git a/.buildkite/PrecompileCI/src/PrecompileCI.jl b/.buildkite/PrecompileCI/src/PrecompileCI.jl index 23db0d43c97..c391bdec411 100644 --- a/.buildkite/PrecompileCI/src/PrecompileCI.jl +++ b/.buildkite/PrecompileCI/src/PrecompileCI.jl @@ -3,7 +3,6 @@ module PrecompileCI using PrecompileTools, Logging import ClimaAtmos as CA import ClimaComms -import ClimaCore: InputOutput, Meshes, Spaces, Quadratures import ClimaParams @compile_workload begin @@ -14,34 +13,43 @@ import ClimaParams x_elem = y_elem = 2 x_max = y_max = 1e8 z_max = FT(30000.0) - dz_bottom = FT(500) # other values? - z_stretch = Meshes.HyperbolicTangentStretching(dz_bottom) # Meshes.Uniform() - bubble = true # false - parsed_args = - Dict{String, Any}("topography" => "NoWarp", "topo_smoothing" => false) - comms_ctx = ClimaComms.context(ClimaComms.CPUSingleThreaded()) - deep = false - - # constants - quad = Quadratures.GLL{4}() + dz_bottom = FT(500) + z_stretch = true + bubble = true + nh_poly = 3 # GLL{4} = nh_poly + 1 + # TODO: compile CUDA methods as well + context = ClimaComms.context(ClimaComms.CPUSingleThreaded()) + topography = CA.NoTopography() params = CA.ClimaAtmosParameters(FT) radius = CA.Parameters.planet_radius(params) - # Sphere - horz_mesh = CA.cubed_sphere_mesh(; radius, h_elem) - h_space = CA.make_horizontal_space(horz_mesh, quad, comms_ctx, bubble) - CA.make_hybrid_spaces(h_space, z_max, z_elem, z_stretch; parsed_args) - - # box - horizontal_mesh = CA.periodic_rectangle_mesh(; x_max, y_max, x_elem, y_elem) - h_space = CA.make_horizontal_space(horizontal_mesh, quad, comms_ctx, bubble) - # This is broken - # CA.make_hybrid_spaces(h_space, z_max, z_elem, z_stretch; parsed_args) - - # plane - horizontal_mesh = CA.periodic_line_mesh(; x_max, x_elem) - h_space = CA.make_horizontal_space(horizontal_mesh, quad, comms_ctx, bubble) + sphere_grid = CA.SphereGrid( + FT; + context, + radius, h_elem, nh_poly, + z_elem, z_max, z_stretch, dz_bottom, + bubble, topography, + ) + box_grid = CA.BoxGrid( + FT; + context, + x_elem, x_max, y_elem, y_max, nh_poly, periodic_x = true, periodic_y = true, + z_elem, z_max, z_stretch, dz_bottom, + bubble, topography, + ) + plane_grid = CA.PlaneGrid( + FT; + context, + x_elem, x_max, nh_poly, periodic_x = true, + z_elem, z_max, z_stretch, dz_bottom, + bubble, topography, + ) + column_grid = CA.ColumnGrid( + FT; context, z_elem, z_max, z_stretch, dz_bottom, + ) + all_grids = (sphere_grid, box_grid, plane_grid, column_grid) + foreach(CA.get_spaces, all_grids) end end -end # module Precompile +end # module PrecompileCI diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index dda81676a28..2a19823cfb8 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -1208,7 +1208,7 @@ steps: CLIMACOMMS_DEVICE: "CUDA" CLIMA_NAME_CUDA_KERNELS_FROM_STACK_TRACE: "true" agents: - slurm_mem: 24GB + slurm_mem: 28GB slurm_gpus: 1 - group: "Flame graphs" diff --git a/NEWS.md b/NEWS.md index db82d7478b7..148bf1585b0 100644 --- a/NEWS.md +++ b/NEWS.md @@ -4,6 +4,9 @@ ClimaAtmos.jl Release Notes main ------- +PR [#4021](https://github.com/CliMA/ClimaAtmos.jl/pull/4021) uses ClimaCore +convenience constructors to create spaces without an AtmosConfig. + v0.31.5 ------- PR [#3975](https://github.com/CliMA/ClimaAtmos.jl/pull/3975) updates the pressure gradient formulation to subtract a reference state and use the Exner pressure. diff --git a/Project.toml b/Project.toml index 35f11f5b7d7..e2dc1e76a76 100644 --- a/Project.toml +++ b/Project.toml @@ -44,12 +44,12 @@ ArgParse = "1" Artifacts = "1" AtmosphericProfilesLibrary = "0.1.7" ClimaComms = "0.6.9" -ClimaCore = "0.14.37" -ClimaDiagnostics = "0.2.12" +ClimaCore = "0.14.43" +ClimaDiagnostics = "0.2.13" ClimaInterpolations = "0.1.0" ClimaParams = "1.0.2" ClimaTimeSteppers = "0.8.2" -ClimaUtilities = "0.1.23" +ClimaUtilities = "0.1.27" CloudMicrophysics = "0.28, 0.29" Dates = "1" ForwardDiff = "1" diff --git a/calibration/experiments/gcm_driven_scm/helper_funcs.jl b/calibration/experiments/gcm_driven_scm/helper_funcs.jl index 6370559c923..54f3aacfad3 100644 --- a/calibration/experiments/gcm_driven_scm/helper_funcs.jl +++ b/calibration/experiments/gcm_driven_scm/helper_funcs.jl @@ -23,8 +23,8 @@ CLIMADIAGNOSTICS_LES_NAME_MAP = """Get z cell centers coordinates for CA run, given config. """ function get_z_grid(atmos_config; z_max = nothing) params = CA.ClimaAtmosParameters(atmos_config) - spaces = - CA.get_spaces(atmos_config.parsed_args, params, atmos_config.comms_ctx) + grid = CA.get_grid(atmos_config.parsed_args, params, atmos_config.comms_ctx) + spaces = CA.get_spaces(grid, atmos_config.comms_ctx) coord = CA.Fields.coordinate_field(spaces.center_space) z_vec = convert(Vector{Float64}, parent(coord.z)[:]) if !isnothing(z_max) diff --git a/docs/src/api.md b/docs/src/api.md index 264468791d9..cda349e74c3 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -44,12 +44,21 @@ ClimaAtmos.InitialConditions.Soares ClimaAtmos.InitialConditions.RCEMIPIIProfile ``` -### Helper +## Helper ```@docs ClimaAtmos.InitialConditions.ColumnInterpolatableField ``` +## Grids + +```@docs +ClimaAtmos.ColumnGrid +ClimaAtmos.SphereGrid +ClimaAtmos.PlaneGrid +ClimaAtmos.BoxGrid +``` + ## Jacobian ```@docs @@ -69,6 +78,8 @@ ClimaAtmos.ScharTopography ClimaAtmos.EarthTopography ClimaAtmos.DCMIP200Topography ClimaAtmos.Hughes2023Topography +ClimaAtmos.SLEVEWarp +ClimaAtmos.LinearWarp ``` ### Internals diff --git a/examples/topography_spectra.jl b/examples/topography_spectra.jl index ab01bb6cd02..90407f76381 100644 --- a/examples/topography_spectra.jl +++ b/examples/topography_spectra.jl @@ -14,9 +14,18 @@ import ClimaCoreSpectra: power_spectrum_2d const AA = AtmosArtifacts +using ClimaCore: + Geometry, Domains, Meshes, Topologies, Spaces, Grids, Hypsography, Fields +using ClimaComms +using ClimaUtilities: SpaceVaryingInputs.SpaceVaryingInput -function mask(x::FT) where {FT} - return x * FT(x > 0) +# Include helper functions from test directory +include(joinpath(@__DIR__, "..", "test", "test_helpers.jl")) + +# h_elem is the number of elements per side of every panel (6 panels in total) +function cubed_sphere_mesh(; radius, h_elem) + domain = Domains.SphereDomain(radius) + return Meshes.EquiangularCubedSphere(domain, h_elem) end """ @@ -37,10 +46,10 @@ function generate_spaces(; ) FT = Float32 cubed_sphere_mesh = - CA.cubed_sphere_mesh(; radius = FT(planet_radius), h_elem) + cubed_sphere_mesh(; radius = FT(planet_radius), h_elem) quad = Quadratures.GLL{4}() comms_ctx = ClimaComms.context() - h_space = CA.make_horizontal_space(cubed_sphere_mesh, quad, comms_ctx, true) + h_space = make_horizontal_space(cubed_sphere_mesh, quad, comms_ctx, true) Δh_scale = Spaces.node_horizontal_length_scale(h_space) @assert h_space isa CC.Spaces.SpectralElementSpace2D coords = CC.Fields.coordinate_field(h_space) @@ -49,7 +58,6 @@ function generate_spaces(; "z", h_space, ) - elev_from_file = @. mask(elev_from_file) # Fractional damping of smallest resolved scale # Approximated as k₀ ≈ 1/Δx, with n_attenuation # the factor by which we wish to damp wavenumber @@ -58,7 +66,7 @@ function generate_spaces(; κ = diff_courant * Δh_scale^2 maxiter = Int(round(log(n_attenuation) / diff_courant)) diffuse_surface_elevation!(elev_from_file; κ, dt = FT(1), maxiter) - elev_from_file = @. mask(elev_from_file) + @. elev_from_file = max(elev_from_file, FT(0)) return elev_from_file end diff --git a/post_processing/ci_plots.jl b/post_processing/ci_plots.jl index a69d2242656..50f5e18b7b1 100644 --- a/post_processing/ci_plots.jl +++ b/post_processing/ci_plots.jl @@ -437,8 +437,6 @@ make_plots_generic( simulation_path, vars, time = LAST_SNAP, - x = 0.0, # Our columns are still 3D objects... - y = 0.0, more_kwargs = YLINEARSCALE, ) ``` @@ -453,8 +451,6 @@ make_plots_generic( simulation_path, vars, time = LAST_SNAP, - x = 0.0, # Our columns are still 3D objects... - y = 0.0, more_kwargs = YLINEARSCALE, ) ``` @@ -589,14 +585,23 @@ ColumnPlots = Union{ function make_plots(::ColumnPlots, output_paths::Vector{<:AbstractString}) simdirs = SimDir.(output_paths) short_names = ["ta", "wa"] - vars = map_comparison(get, simdirs, short_names) + vars = map_comparison(simdirs, short_names) do simdir, short_name + var = get(simdir; short_name) + # For vertical-only (FiniteDifferenceGrid) spaces, the data may have + # extra singleton dimensions. Check and squeeze if needed. + if haskey(var.dims, "x") && length(var.dims["x"]) == 1 + var = slice(var; x = var.dims["x"][1]) + end + if haskey(var.dims, "y") && length(var.dims["y"]) == 1 + var = slice(var; y = var.dims["y"][1]) + end + return var + end make_plots_generic( output_paths, vars, time = LAST_SNAP, - x = 0.0, # Our columns are still 3D objects... - y = 0.0, MAX_NUM_COLS = length(simdirs), more_kwargs = YLINEARSCALE, ) @@ -644,7 +649,7 @@ function make_plots( end vars = [ - slice(get(simdir; short_name), x = 0.0, y = 0.0) for + get(simdir; short_name) for short_name in short_names ] @@ -689,7 +694,7 @@ function make_plots( surface_precip = read_var(simdir.variable_paths["pr"]["inst"]["10s"]) viz.line_plot1D!( fig, - slice(surface_precip, x = 0.0, y = 0.0); + surface_precip; p_loc = [pr_row, 1:3], ) @@ -1250,6 +1255,12 @@ EDMFBoxPlots = Union{ Val{:diagnostic_edmfx_dycoms_rf01_box}, Val{:diagnostic_edmfx_trmm_box_0M}, Val{:diagnostic_edmfx_dycoms_rf01_explicit_box}, + Val{:prognostic_edmfx_bomex_box}, + Val{:rcemipii_box_diagnostic_edmfx}, + Val{:diagnostic_edmfx_trmm_stretched_box}, +} + +EDMFColumnPlots = Union{ Val{:prognostic_edmfx_adv_test_column}, Val{:prognostic_edmfx_gabls_column}, Val{:prognostic_edmfx_gabls_column_sparse_autodiff}, @@ -1262,12 +1273,10 @@ EDMFBoxPlots = Union{ Val{:prognostic_edmfx_simpleplume_column}, Val{:prognostic_edmfx_gcmdriven_column}, Val{:prognostic_edmfx_tv_era5driven_column}, - Val{:prognostic_edmfx_bomex_box}, Val{:prognostic_edmfx_soares_column}, - Val{:diagnostic_edmfx_trmm_stretched_box}, } -EDMFBoxPlotsWithPrecip = Union{ +EDMFColumnPlotsWithPrecip = Union{ Val{:prognostic_edmfx_rico_column}, Val{:prognostic_edmfx_rico_implicit_column}, Val{:prognostic_edmfx_rico_column_2M}, @@ -1360,61 +1369,38 @@ end function make_plots( sim_type::Union{ EDMFBoxPlots, - EDMFBoxPlotsWithPrecip, DiagEDMFBoxPlotsWithPrecip, + EDMFColumnPlots, + EDMFColumnPlotsWithPrecip, }, output_paths::Vector{<:AbstractString}, ) simdirs = SimDir.(output_paths) - if sim_type isa EDMFBoxPlotsWithPrecip + # Determine if this is a box or column type + is_box = sim_type isa Union{EDMFBoxPlots, DiagEDMFBoxPlotsWithPrecip} + + # Determine precipitation names based on type + if sim_type isa DiagEDMFBoxPlotsWithPrecip + precip_names = ("husra", "hussn", "husraup", "hussnup") + elseif sim_type isa EDMFColumnPlotsWithPrecip if sim_type isa Val{:prognostic_edmfx_rico_column_2M} precip_names = ( - "husra", - "hussn", - "husraup", - "hussnup", - "husraen", - "hussnen", - "cdnc", - "ncra", - "cdncup", - "ncraup", - "cdncen", - "ncraen", + "husra", "hussn", "husraup", "hussnup", "husraen", "hussnen", + "cdnc", "ncra", "cdncup", "ncraup", "cdncen", "ncraen", ) else precip_names = ("husra", "hussn", "husraup", "hussnup", "husraen", "hussnen") end - elseif sim_type isa DiagEDMFBoxPlotsWithPrecip - precip_names = ("husra", "hussn", "husraup", "hussnup") else precip_names = () end short_names = [ - "wa", - "waup", - "ta", - "taup", - "hus", - "husup", - "arup", - "tke", - "ua", - "thetaa", - "thetaaup", - "ha", - "haup", - "hur", - "hurup", - "lmix", - "cl", - "clw", - "clwup", - "cli", - "cliup", + "wa", "waup", "ta", "taup", "hus", "husup", "arup", "tke", "ua", + "thetaa", "thetaaup", "ha", "haup", "hur", "hurup", "lmix", + "cl", "clw", "clwup", "cli", "cliup", precip_names..., ] reduction = "inst" @@ -1431,13 +1417,13 @@ function make_plots( short_name_tuples = pair_edmf_names(short_names) var_groups_zt = map_comparison(simdirs, short_name_tuples) do simdir, name_tuple - return [ - slice( - get(simdir; short_name, reduction, period), - x = 0.0, - y = 0.0, - ) for short_name in name_tuple - ] + vars = map(short_name -> get(simdir; short_name, reduction, period), name_tuple) + # For box types, slice to a point (x=0, y=0) + if is_box + return map(var -> slice(var, x = 0.0, y = 0.0), vars) + else + return vars + end end var_groups_z = [ @@ -1456,7 +1442,7 @@ function make_plots( make_plots_generic( output_paths, - vcat(var_groups_zt...), + vcat((var_groups_zt...)...), plot_fn = plot_parsed_attribute_title!, summary_files = [tmp_file], MAX_NUM_COLS = 2, diff --git a/reproducibility_tests/ref_counter.jl b/reproducibility_tests/ref_counter.jl index 668ae150e29..04ca424566e 100644 --- a/reproducibility_tests/ref_counter.jl +++ b/reproducibility_tests/ref_counter.jl @@ -1,4 +1,4 @@ -281 +282 # **README** # @@ -20,6 +20,9 @@ #= +282 +- Use ClimaCore.CommonSpaces constructors for Atmos spaces + 281 - Clean up ci, remove some jobs diff --git a/src/ClimaAtmos.jl b/src/ClimaAtmos.jl index a7a8c587d19..7a36124a118 100644 --- a/src/ClimaAtmos.jl +++ b/src/ClimaAtmos.jl @@ -17,7 +17,6 @@ import .Parameters as CAP include(joinpath("utils", "abbreviations.jl")) include(joinpath("utils", "gpu_compat.jl")) -include(joinpath("utils", "common_spaces.jl")) include(joinpath("solver", "types.jl")) include(joinpath("solver", "cli_options.jl")) include(joinpath("utils", "utilities.jl")) @@ -156,6 +155,7 @@ import .Diagnostics as CAD include(joinpath("callbacks", "get_callbacks.jl")) include(joinpath("simulation", "AtmosSimulations.jl")) +include(joinpath("simulation", "grids.jl")) include(joinpath("solver", "model_getters.jl")) # high-level (using parsed_args) model getters include(joinpath("solver", "type_getters.jl")) diff --git a/src/cache/temporary_quantities.jl b/src/cache/temporary_quantities.jl index b5cf4d432bc..6dbed220836 100644 --- a/src/cache/temporary_quantities.jl +++ b/src/cache/temporary_quantities.jl @@ -28,7 +28,6 @@ function temporary_quantities(Y, atmos) center_space, face_space = axes(Y.c), axes(Y.f) FT = Spaces.undertype(center_space) - CTh = CTh_vector_type(Y.c) uvw_vec = UVW(FT(0), FT(0), FT(0)) return (; ᶠtemp_scalar = Fields.Field(FT, face_space), # ᶠp, ᶠρK_h @@ -81,7 +80,7 @@ function temporary_quantities(Y, atmos) # TODO: Remove this hack sfc_temp_C3 = Fields.Field(C3{FT}, Spaces.level(face_space, half)), # ρ_flux_χ # Implicit solver cache: - ∂ᶜK_∂ᶜuₕ = similar(Y.c, DiagonalMatrixRow{Adjoint{FT, CTh{FT}}}), + ∂ᶜK_∂ᶜuₕ = similar(Y.c, DiagonalMatrixRow{Adjoint{FT, CT12{FT}}}), ∂ᶜK_∂ᶠu₃ = similar(Y.c, BidiagonalMatrixRow{Adjoint{FT, CT3{FT}}}), ᶠp_grad_matrix = similar(Y.f, BidiagonalMatrixRow{C3{FT}}), ᶠbidiagonal_matrix_ct3 = similar(Y.f, BidiagonalMatrixRow{CT3{FT}}), diff --git a/src/callbacks/get_callbacks.jl b/src/callbacks/get_callbacks.jl index 8aed9a1a96d..cc7dc343f5b 100644 --- a/src/callbacks/get_callbacks.jl +++ b/src/callbacks/get_callbacks.jl @@ -168,6 +168,7 @@ function get_diagnostics(parsed_args, atmos_model, Y, p, sim_info, output_dir) FT(time_to_seconds(parsed_args["t_end"]) - t_start), start_date; output_writer = netcdf_writer, + topography = has_topography(axes(Y.c)), )..., diagnostics..., ] diff --git a/src/diagnostics/default_diagnostics.jl b/src/diagnostics/default_diagnostics.jl index 50f7f24b96a..9abc1b9d812 100644 --- a/src/diagnostics/default_diagnostics.jl +++ b/src/diagnostics/default_diagnostics.jl @@ -30,6 +30,7 @@ function default_diagnostics( duration, start_date::DateTime; output_writer, + topography = true, ) # Unfortunately, [] is not treated nicely in a map (we would like it to be "excluded"), # so we need to manually filter out the submodels that don't have defaults associated @@ -48,7 +49,7 @@ function default_diagnostics( # We use a map because we want to ensure that diagnostics is a well defined type, not # Any. This reduces latency. return vcat( - core_default_diagnostics(output_writer, duration, start_date), + core_default_diagnostics(output_writer, duration, start_date; topography), map(non_empty_fields) do field default_diagnostics( getfield(model, field), @@ -131,7 +132,7 @@ end ######## # Core # ######## -function core_default_diagnostics(output_writer, duration, start_date) +function core_default_diagnostics(output_writer, duration, start_date; topography = true) core_diagnostics = [ "ts", "ta", @@ -168,22 +169,26 @@ function core_default_diagnostics(output_writer, duration, start_date) min_func = (args...; kwargs...) -> hourly_min(FT, args...; kwargs...) max_func = (args...; kwargs...) -> hourly_max(FT, args...; kwargs...) end + # Base diagnostics for all cases + base_diagnostics = [ + average_func(core_diagnostics...; output_writer, start_date)..., + min_func("ts"; output_writer, start_date), + max_func("ts"; output_writer, start_date), + ] - return [ - # We need to compute the topography at the beginning of the simulation (and only at - # the beginning), so we set output/compute_schedule_func to false. It is still - # computed at the very beginning - ScheduledDiagnostic(; + # Prepend orography diagnostic if topography is enabled + if topography + orog_diagnostic = ScheduledDiagnostic(; variable = get_diagnostic_variable("orog"), output_schedule_func = (integrator) -> false, compute_schedule_func = (integrator) -> false, output_writer, output_short_name = "orog_inst", - ), - average_func(core_diagnostics...; output_writer, start_date)..., - min_func("ts"; output_writer, start_date), - max_func("ts"; output_writer, start_date), - ] + ) + return [orog_diagnostic, base_diagnostics...] + else + return base_diagnostics + end end ################## diff --git a/src/parameterized_tendencies/sponge/viscous_sponge.jl b/src/parameterized_tendencies/sponge/viscous_sponge.jl index fcfc409849f..7b61665eb03 100644 --- a/src/parameterized_tendencies/sponge/viscous_sponge.jl +++ b/src/parameterized_tendencies/sponge/viscous_sponge.jl @@ -13,7 +13,9 @@ import ClimaCore.Spaces as Spaces αₘ(s, z) * ζ_viscous(s, z, zmax) function viscous_sponge_tendency_uₕ(ᶜuₕ, s) - s isa Nothing && return NullBroadcasted() + if s isa Nothing || axes(ᶜuₕ) isa Spaces.FiniteDifferenceSpace + return NullBroadcasted() + end (; ᶜz, ᶠz) = z_coordinate_fields(axes(ᶜuₕ)) zmax = z_max(axes(ᶠz)) return @. lazy( diff --git a/src/prognostic_equations/advection.jl b/src/prognostic_equations/advection.jl index 90c9446b44e..ba910353d78 100644 --- a/src/prognostic_equations/advection.jl +++ b/src/prognostic_equations/advection.jl @@ -240,7 +240,7 @@ NVTX.@annotate function explicit_vertical_advection_tendency!(Yₜ, Y, p, t) if point_type <: Geometry.Abstract3DPoint @. ᶜω³ = wcurlₕ(Y.c.uₕ) - elseif point_type <: Geometry.Abstract2DPoint + else @. ᶜω³ = zero(ᶜω³) end diff --git a/src/prognostic_equations/hyperdiffusion.jl b/src/prognostic_equations/hyperdiffusion.jl index 3de2286a09c..a09db27d67f 100644 --- a/src/prognostic_equations/hyperdiffusion.jl +++ b/src/prognostic_equations/hyperdiffusion.jl @@ -31,7 +31,6 @@ end function hyperdiffusion_cache( Y, hyperdiff::ClimaHyperdiffusion, turbconv_model, moisture_model, microphysics_model, ) - quadrature_style = Spaces.quadrature_style(Spaces.horizontal_space(axes(Y.c))) FT = eltype(Y) n = n_mass_flux_subdomains(turbconv_model) diff --git a/src/prognostic_equations/implicit/manual_sparse_jacobian.jl b/src/prognostic_equations/implicit/manual_sparse_jacobian.jl index d9d691b5a61..7f2b69b37e3 100644 --- a/src/prognostic_equations/implicit/manual_sparse_jacobian.jl +++ b/src/prognostic_equations/implicit/manual_sparse_jacobian.jl @@ -71,15 +71,14 @@ function jacobian_cache(alg::ManualSparseJacobian, Y, atmos) approximate_solve_iters, ) = alg FT = Spaces.undertype(axes(Y.c)) - CTh = CTh_vector_type(axes(Y.c)) DiagonalRow = DiagonalMatrixRow{FT} TridiagonalRow = TridiagonalMatrixRow{FT} BidiagonalRow_C3 = BidiagonalMatrixRow{C3{FT}} - TridiagonalRow_ACTh = TridiagonalMatrixRow{Adjoint{FT, CTh{FT}}} + TridiagonalRow_ACT12 = TridiagonalMatrixRow{Adjoint{FT, CT12{FT}}} BidiagonalRow_ACT3 = BidiagonalMatrixRow{Adjoint{FT, CT3{FT}}} - BidiagonalRow_C3xACTh = - BidiagonalMatrixRow{typeof(zero(C3{FT}) * zero(CTh{FT})')} + BidiagonalRow_C3xACT12 = + BidiagonalMatrixRow{typeof(zero(C3{FT}) * zero(CT12{FT})')} DiagonalRow_C3xACT3 = DiagonalMatrixRow{typeof(zero(C3{FT}) * zero(CT3{FT})')} TridiagonalRow_C3xACT3 = @@ -150,7 +149,7 @@ function jacobian_cache(alg::ManualSparseJacobian, Y, atmos) MatrixFields.unrolled_map( name -> (name, @name(c.uₕ)) => - similar(Y.c, TridiagonalRow_ACTh), + similar(Y.c, TridiagonalRow_ACT12), active_scalar_names, ) : () )..., @@ -162,7 +161,7 @@ function jacobian_cache(alg::ManualSparseJacobian, Y, atmos) name -> (@name(f.u₃), name) => similar(Y.f, BidiagonalRow_C3), active_scalar_names, )..., - (@name(f.u₃), @name(c.uₕ)) => similar(Y.f, BidiagonalRow_C3xACTh), + (@name(f.u₃), @name(c.uₕ)) => similar(Y.f, BidiagonalRow_C3xACT12), (@name(f.u₃), @name(f.u₃)) => similar(Y.f, TridiagonalRow_C3xACT3), ) @@ -407,7 +406,6 @@ function update_jacobian!(alg::ManualSparseJacobian, cache, Y, p, dtγ, t) rs = p.atmos.rayleigh_sponge FT = Spaces.undertype(axes(Y.c)) - CTh = CTh_vector_type(axes(Y.c)) one_C3xACT3 = C3(FT(1)) * CT3(FT(1))' cv_d = FT(CAP.cv_d(params)) @@ -458,10 +456,10 @@ function update_jacobian!(alg::ManualSparseJacobian, cache, Y, p, dtγ, t) if use_derivative(topography_flag) @. ∂ᶜK_∂ᶜuₕ = DiagonalMatrixRow( - adjoint(CTh(ᶜuₕ)) + adjoint(ᶜinterp(ᶠu₃)) * g³ʰ(ᶜgⁱʲ), + adjoint(CT12(ᶜuₕ)) + adjoint(ᶜinterp(ᶠu₃)) * g³ʰ(ᶜgⁱʲ), ) else - @. ∂ᶜK_∂ᶜuₕ = DiagonalMatrixRow(adjoint(CTh(ᶜuₕ))) + @. ∂ᶜK_∂ᶜuₕ = DiagonalMatrixRow(adjoint(CT12(ᶜuₕ))) end @. ∂ᶜK_∂ᶠu₃ = ᶜinterp_matrix() ⋅ DiagonalMatrixRow(adjoint(CT3(ᶠu₃))) + diff --git a/src/simulation/grids.jl b/src/simulation/grids.jl new file mode 100644 index 00000000000..1dd5849fe61 --- /dev/null +++ b/src/simulation/grids.jl @@ -0,0 +1,335 @@ +import ClimaCore: Geometry, Hypsography, Fields, Spaces, Meshes, Grids, CommonGrids +using ClimaUtilities: SpaceVaryingInputs.SpaceVaryingInput +import .AtmosArtifacts as AA +import ClimaComms + +export SphereGrid, ColumnGrid, BoxGrid, PlaneGrid + +""" + SphereGrid(::Type{FT}; kwargs...) + +Create an ExtrudedCubedSphereGrid with topography support. + +# Arguments +- `FT`: the floating-point type [`Float32`, `Float64`] + +# Keyword Arguments +- `context = ClimaComms.context()`: the ClimaComms communications context +- `z_elem = 10`: the number of z-points +- `z_max = 30000.0`: the domain maximum along the z-direction +- `z_stretch = true`: whether to use vertical stretching +- `dz_bottom = 500.0`: bottom layer thickness for stretching +- `radius = 6.371229e6`: the radius of the cubed sphere +- `h_elem = 6`: the number of horizontal elements per side of every panel (6 + panels in total) +- `nh_poly = 3`: the polynomial order. Note: The number of quadrature points in + 1D within each horizontal element is then `n_quad_points = nh_poly + 1` +- `bubble = false`: enables the "bubble correction" for more accurate element + areas when computing the spectral element space +- `deep_atmosphere = true`: use deep atmosphere equations and metric terms, + otherwise assume columns are cylindrical (shallow atmosphere) +- `topography = NoTopography()`: topography type +- `topography_damping_factor = 5.0`: factor by which smallest resolved + length-scale is to be damped +- `mesh_warp_type = SLEVEWarp{FT}()`: mesh warping type ([`SLEVEWarp`](@ref) or + [`LinearWarp`](@ref)) +- `topo_smoothing = false`: apply topography smoothing +""" +function SphereGrid( + ::Type{FT}; + context = ClimaComms.context(), + z_elem = 10, + z_max = 30000.0, + z_stretch = true, + dz_bottom = 500.0, + radius = 6.371229e6, + h_elem = 6, + nh_poly = 3, + bubble = false, + deep_atmosphere = true, + topography::AbstractTopography = NoTopography(), + topography_damping_factor = 5.0, + mesh_warp_type::MeshWarpType = SLEVEWarp{FT}(), + topo_smoothing = false, +) where {FT} + n_quad_points = nh_poly + 1 + stretch = + z_stretch ? Meshes.HyperbolicTangentStretching{FT}(dz_bottom) : Meshes.Uniform() + hypsography_fun = hypsography_function_from_topography( + FT, topography, topography_damping_factor, mesh_warp_type, topo_smoothing, + ) + + global_geometry = if deep_atmosphere + Geometry.DeepSphericalGlobalGeometry{FT}(radius) + else + Geometry.ShallowSphericalGlobalGeometry{FT}(radius) + end + + grid = CommonGrids.ExtrudedCubedSphereGrid( + FT; + z_elem, z_min = 0, z_max, radius, h_elem, + n_quad_points, + device = ClimaComms.device(context), + context, + stretch, + hypsography_fun, + global_geometry, + enable_bubble = bubble, + ) + + return grid +end + +""" + ColumnGrid(::Type{FT}; kwargs...) + +Create a ColumnGrid. + +# Arguments +- `FT`: the floating-point type [`Float32`, `Float64`] + +# Keyword Arguments +- `context = ClimaComms.context()`: the ClimaComms communications context +- `z_elem = 10`: the number of z-points +- `z_max = 30000.0`: the domain maximum along the z-direction +- `z_stretch = true`: whether to use vertical stretching +- `dz_bottom = 500.0`: bottom layer thickness for stretching +""" +function ColumnGrid( + ::Type{FT}; + context = ClimaComms.context(), + z_elem = 10, + z_max = 30000.0, + z_stretch = true, + dz_bottom = 500.0, +) where {FT} + stretch = + z_stretch ? Meshes.HyperbolicTangentStretching{FT}(dz_bottom) : Meshes.Uniform() + z_mesh = CommonGrids.DefaultZMesh(FT; z_min = 0, z_max, z_elem, stretch) + grid = CommonGrids.ColumnGrid( + FT; + z_elem, z_min = 0, z_max, z_mesh, + device = ClimaComms.device(context), + context, + stretch, + ) + + return grid +end + +""" + BoxGrid(::Type{FT}; kwargs...) + +Create a Box3DGrid with topography support. + +# Arguments +- `FT`: the floating-point type [`Float32`, `Float64`] + +# Keyword Arguments +- `context = ClimaComms.context()`: the ClimaComms communications context +- `x_elem = 6`: the number of x-points +- `x_max = 300000.0`: the domain maximum along the x-direction +- `y_elem = 6`: the number of y-points +- `y_max = 300000.0`: the domain maximum along the y-direction +- `z_elem = 10`: the number of z-points +- `z_max = 30000.0`: the domain maximum along the z-direction +- `nh_poly = 3`: the polynomial order. Note: The number of quadrature points in + 1D within each horizontal element is then `n_quad_points = nh_poly + 1` +- `z_stretch = true`: whether to use vertical stretching +- `dz_bottom = 500.0`: bottom layer thickness for stretching +- `bubble = false`: enables the "bubble correction" for more accurate element + areas when computing the spectral element space. +- `periodic_x = true`: use periodic domain along x-direction +- `periodic_y = true`: use periodic domain along y-direction +- `topography = NoTopography()`: topography type +- `topography_damping_factor = 5.0`: factor by which smallest resolved + length-scale is to be damped +- `mesh_warp_type = LinearWarp()`: mesh warping type ([`SLEVEWarp`](@ref) or + [`LinearWarp`](@ref)) +- `topo_smoothing = false`: apply topography smoothing +""" +function BoxGrid( + ::Type{FT}; + context = ClimaComms.context(), + x_elem = 6, + x_max = 300000.0, + y_elem = 6, + y_max = 300000.0, + z_elem = 10, + z_max = 30000.0, + nh_poly = 3, + z_stretch = true, + dz_bottom = 500.0, + bubble = false, + periodic_x = true, + periodic_y = true, + topography::AbstractTopography = NoTopography(), + topography_damping_factor = 5.0, + mesh_warp_type::MeshWarpType = LinearWarp(), + topo_smoothing = false, +) where {FT} + n_quad_points = nh_poly + 1 + stretch = + z_stretch ? Meshes.HyperbolicTangentStretching{FT}(dz_bottom) : Meshes.Uniform() + hypsography_fun = hypsography_function_from_topography( + FT, topography, topography_damping_factor, mesh_warp_type, topo_smoothing, + ) + z_mesh = CommonGrids.DefaultZMesh(FT; z_min = 0, z_max, z_elem, stretch) + grid = CommonGrids.Box3DGrid( + FT; + z_elem, x_min = 0, x_max, y_min = 0, y_max, z_min = 0, z_max, + periodic_x, periodic_y, n_quad_points, x_elem, y_elem, + device = ClimaComms.device(context), + context, + stretch, + hypsography_fun, + global_geometry = Geometry.CartesianGlobalGeometry(), + z_mesh, + enable_bubble = bubble, + ) + + return grid +end + +""" + PlaneGrid(::Type{FT}; kwargs...) + +Create a SliceXZGrid with topography support. + +# Arguments +- `FT`: the floating-point type [`Float32`, `Float64`] + +# Keyword Arguments +- `context = ClimaComms.context()`: the ClimaComms communications context +- `x_elem = 6`: the number of x-points +- `x_max = 300000.0`: the domain maximum along the x-direction +- `z_elem = 10`: the number of z-points +- `z_max = 30000.0`: the domain maximum along the z-direction +- `nh_poly = 3`: the polynomial order. Note: The number of quadrature points in + 1D within each horizontal element is then `n_quad_points = nh_poly + 1` +- `z_stretch = true`: whether to use vertical stretching +- `dz_bottom = 500.0`: bottom layer thickness for stretching +- `bubble = false`: enables the "bubble correction" for more accurate element + areas when computing the spectral element space. Note: Currently not supported + by SliceXZGrid in ClimaCore. +- `periodic_x = true`: use periodic domain along x-direction +- `topography = NoTopography()`: topography type +- `topography_damping_factor = 5.0`: factor by which smallest resolved + length-scale is to be damped +- `mesh_warp_type = LinearWarp()`: mesh warping type ([`SLEVEWarp`](@ref) or + [`LinearWarp`](@ref)) +- `topo_smoothing = false`: apply topography smoothing +""" +function PlaneGrid( + ::Type{FT}; + context = ClimaComms.context(), + x_elem = 6, + x_max = 300000.0, + z_elem = 10, + z_max = 30000.0, + nh_poly = 3, + z_stretch = true, + dz_bottom = 500.0, + bubble = false, + periodic_x = true, + topography::AbstractTopography = NoTopography(), + topography_damping_factor = 5.0, + mesh_warp_type::MeshWarpType = LinearWarp(), + topo_smoothing = false, +) where {FT} + n_quad_points = nh_poly + 1 + stretch = + z_stretch ? Meshes.HyperbolicTangentStretching{FT}(dz_bottom) : Meshes.Uniform() + hypsography_fun = hypsography_function_from_topography( + FT, topography, topography_damping_factor, mesh_warp_type, topo_smoothing, + ) + z_mesh = CommonGrids.DefaultZMesh(FT; z_min = 0, z_max, z_elem, stretch) + + grid = CommonGrids.SliceXZGrid( + FT; + z_elem, x_elem, x_min = 0, x_max, z_min = 0, z_max, z_mesh, + periodic_x, + n_quad_points, + device = ClimaComms.device(context), + context, + stretch, + hypsography_fun, + global_geometry = Geometry.CartesianGlobalGeometry(), + ) + + return grid +end + +""" + hypsography_function_from_topography( + FT, topography, topography_damping_factor, mesh_warp_type, topo_smoothing) + +Create a hypsography function that handles topography integration. +""" +function hypsography_function_from_topography( + ::Type{FT}, + topography::AbstractTopography, + topography_damping_factor, + mesh_warp_type::MeshWarpType, + topo_smoothing, +) where {FT} + return function hypsography(h_grid, z_grid) + topography isa NoTopography && return Hypsography.Flat() + + # Create horizontal space to work with topography + h_space = if h_grid isa Grids.SpectralElementGrid1D + Spaces.SpectralElementSpace1D(h_grid) + elseif h_grid isa Grids.SpectralElementGrid2D + Spaces.SpectralElementSpace2D(h_grid) + else + error("Unsupported horizontal grid type $(typeof(h_grid))") + end + + # Load topography data + if topography isa EarthTopography + context = ClimaComms.context(h_space) + z_surface = SpaceVaryingInput( + AA.earth_orography_file_path(; context), + "z", + h_space, + ) + @info "Remapping Earth orography from ETOPO2022 data onto horizontal space" + else + z_surface = SpaceVaryingInput(topography_function(topography), h_space) + @info "Using $(nameof(typeof(topography))) orography" + end + + if topography isa EarthTopography + # Diffuse Earth topography to remove small-scale features, using a + # diffusion Courant number (CFL = νΔt/Δx²) to control smoothing + diff_courant = FT(0.05) + Δh_scale = Spaces.node_horizontal_length_scale(h_space) + κ = FT(diff_courant * Δh_scale^2) + maxiter = Int(round(log(topography_damping_factor) / diff_courant)) + # Coefficient for horizontal diffusion may alternatively be + # determined from the empirical parameters suggested by E3SM v1/v2 + # Topography documentation found here: + # https://acme-climate.atlassian.net/wiki/spaces/DOC/pages/1456603764/V1+Topography+GLL+grids + Hypsography.diffuse_surface_elevation!(z_surface; κ, dt = FT(1), maxiter) + @. z_surface = max(z_surface, 0) + elseif topo_smoothing + Hypsography.diffuse_surface_elevation!(z_surface) + end + + # Create hypsography from mesh warp type + if mesh_warp_type isa SLEVEWarp + @info "SLEVE mesh warp (eta=$(mesh_warp_type.eta), s=$(mesh_warp_type.s))" + hypsography = Hypsography.SLEVEAdaption( + Geometry.ZPoint.(z_surface), + FT(mesh_warp_type.eta), + FT(mesh_warp_type.s), + ) + elseif mesh_warp_type isa LinearWarp + @info "Linear mesh warp" + hypsography = Hypsography.LinearAdaption(Geometry.ZPoint.(z_surface)) + else + error("Undefined mesh-warping option $(nameof(typeof(mesh_warp_type)))") + end + return hypsography + end +end diff --git a/src/solver/type_getters.jl b/src/solver/type_getters.jl index b99c0e22c45..01cbcb06129 100644 --- a/src/solver/type_getters.jl +++ b/src/solver/type_getters.jl @@ -6,7 +6,7 @@ import ClimaUtilities.OutputPathGenerator import ClimaCore: InputOutput, Meshes, Spaces, Quadratures import ClimaAtmos.RRTMGPInterface as RRTMGPI import ClimaAtmos as CA -import ClimaCore.Fields +import ClimaCore: Fields, Grids import ClimaTimeSteppers as CTS import Logging @@ -240,127 +240,38 @@ function get_numerics(parsed_args, FT) return numerics end -function get_spaces(parsed_args, params, comms_ctx) +""" + get_spaces(grid) - FT = eltype(params) - z_elem = Int(parsed_args["z_elem"]) - z_max = FT(parsed_args["z_max"]) - dz_bottom = FT(parsed_args["dz_bottom"]) - bubble = parsed_args["bubble"] - deep = parsed_args["deep_atmosphere"] - - h_elem = parsed_args["h_elem"] - radius = CAP.planet_radius(params) - center_space, face_space = if parsed_args["config"] == "sphere" - nh_poly = parsed_args["nh_poly"] - quad = Quadratures.GLL{nh_poly + 1}() - horizontal_mesh = cubed_sphere_mesh(; radius, h_elem) - h_space = - make_horizontal_space(horizontal_mesh, quad, comms_ctx, bubble) - z_stretch = if parsed_args["z_stretch"] - Meshes.HyperbolicTangentStretching(dz_bottom) - else - Meshes.Uniform() - end - make_hybrid_spaces(h_space, z_max, z_elem, z_stretch; deep, parsed_args) - elseif parsed_args["config"] == "column" # single column - @warn "perturb_initstate flag is ignored for single column configuration" - FT = eltype(params) - Δx = FT(1) # Note: This value shouldn't matter, since we only have 1 column. - quad = Quadratures.GL{1}() - horizontal_mesh = periodic_rectangle_mesh(; - x_max = Δx, - y_max = Δx, - x_elem = 1, - y_elem = 1, - ) - if bubble - @warn "Bubble correction not compatible with single column configuration. It will be switched off." - bubble = false - end - h_space = - make_horizontal_space(horizontal_mesh, quad, comms_ctx, bubble) - z_stretch = if parsed_args["z_stretch"] - Meshes.HyperbolicTangentStretching(dz_bottom) - else - Meshes.Uniform() - end - make_hybrid_spaces(h_space, z_max, z_elem, z_stretch; parsed_args) - elseif parsed_args["config"] == "box" - FT = eltype(params) - nh_poly = parsed_args["nh_poly"] - quad = Quadratures.GLL{nh_poly + 1}() - x_elem = Int(parsed_args["x_elem"]) - x_max = FT(parsed_args["x_max"]) - y_elem = Int(parsed_args["y_elem"]) - y_max = FT(parsed_args["y_max"]) - horizontal_mesh = periodic_rectangle_mesh(; - x_max = x_max, - y_max = y_max, - x_elem = x_elem, - y_elem = y_elem, +Create center and face spaces from a ClimaCore grid. +""" +function get_spaces(grid) + if grid isa Grids.ExtrudedFiniteDifferenceGrid + center_space = Spaces.CenterExtrudedFiniteDifferenceSpace(grid) + face_space = Spaces.FaceExtrudedFiniteDifferenceSpace(grid) + elseif grid isa Grids.FiniteDifferenceGrid + center_space = Spaces.CenterFiniteDifferenceSpace(grid) + face_space = Spaces.FaceFiniteDifferenceSpace(grid) + else + error( + """Unsupported grid type: $(typeof(grid)). Expected \ + ExtrudedFiniteDifferenceGrid or FiniteDifferenceGrid""", ) - h_space = - make_horizontal_space(horizontal_mesh, quad, comms_ctx, bubble) - z_stretch = if parsed_args["z_stretch"] - Meshes.HyperbolicTangentStretching(dz_bottom) - else - Meshes.Uniform() - end - make_hybrid_spaces(h_space, z_max, z_elem, z_stretch; parsed_args, deep) - elseif parsed_args["config"] == "plane" - FT = eltype(params) - nh_poly = parsed_args["nh_poly"] - quad = Quadratures.GLL{nh_poly + 1}() - x_elem = Int(parsed_args["x_elem"]) - x_max = FT(parsed_args["x_max"]) - horizontal_mesh = - periodic_line_mesh(; x_max = x_max, x_elem = x_elem) - h_space = - make_horizontal_space(horizontal_mesh, quad, comms_ctx, bubble) - z_stretch = if parsed_args["z_stretch"] - Meshes.HyperbolicTangentStretching(dz_bottom) - else - Meshes.Uniform() - end - make_hybrid_spaces(h_space, z_max, z_elem, z_stretch; parsed_args, deep) - end - ncols = Fields.ncolumns(center_space) - ndofs_total = ncols * z_elem - hspace = Spaces.horizontal_space(center_space) - quad_style = Spaces.quadrature_style(hspace) - Nq = Quadratures.degrees_of_freedom(quad_style) - - @info "Resolution stats: " Nq h_elem z_elem ncols ndofs_total - return (; - center_space, - face_space, - horizontal_mesh, - quad, - z_max, - z_elem, - z_stretch, - ) -end - -function get_spaces_restart(Y) - center_space = axes(Y.c) - face_space = axes(Y.f) + end return (; center_space, face_space) end function get_state_restart(config::AtmosConfig, restart_file, atmos_model_hash) (; parsed_args, comms_ctx) = config - sim_info = get_sim_info(config) + (; start_date) = get_sim_info(config) + use_itime = parsed_args["use_itime"] @assert !isnothing(restart_file) reader = InputOutput.HDF5Reader(restart_file, comms_ctx) Y = InputOutput.read_field(reader, "Y") # TODO: Do not use InputOutput.HDF5 directly t_start = InputOutput.HDF5.read_attribute(reader.file, "time") - t_start = - parsed_args["use_itime"] ? ITime(t_start; epoch = sim_info.start_date) : - t_start + t_start = use_itime ? ITime(t_start; epoch = start_date) : t_start if "atmos_model_hash" in keys(InputOutput.HDF5.attrs(reader.file)) atmos_model_hash_in_restart = InputOutput.HDF5.read_attribute(reader.file, "atmos_model_hash") @@ -630,39 +541,71 @@ function auto_detect_restart_file( return restart_file end -function get_sim_info(config::AtmosConfig) - (; comms_ctx, parsed_args) = config - FT = eltype(config) - (; job_id) = config +import ClimaUtilities.OutputPathGenerator + +""" + setup_output_dir(job_id, output_dir, output_dir_style, detect_restart_file, restart_file, comms_ctx) + +Unified function for setting up output directories and detecting restart files. +Used by both AtmosSimulation constructor and get_simulation. + +Returns a named tuple with: +- `output_dir`: The final output directory path +- `restart_file`: The restart file path (if any) +""" +function setup_output_dir( + job_id, + output_dir, + output_dir_style, + detect_restart_file, + restart_file, + comms_ctx, +) + # Set up base output directory default_output = haskey(ENV, "CI") ? job_id : joinpath("output", job_id) - out_dir = parsed_args["output_dir"] - base_output_dir = isnothing(out_dir) ? default_output : out_dir + base_output_dir = isnothing(output_dir) ? default_output : output_dir allowed_dir_styles = Dict( "activelink" => OutputPathGenerator.ActiveLinkStyle(), "removepreexisting" => OutputPathGenerator.RemovePreexistingStyle(), ) - requested_style = parsed_args["output_dir_style"] + haskey(allowed_dir_styles, lowercase(output_dir_style)) || + error("output_dir_style $(output_dir_style) not available") - haskey(allowed_dir_styles, lowercase(requested_style)) || - error("output_dir_style $(requested_style) not available") + output_dir_style_obj = allowed_dir_styles[lowercase(output_dir_style)] - output_dir_style = allowed_dir_styles[lowercase(requested_style)] - - # We look for a restart before creating a new output dir because we want to - # look for previous folders - restart_file = - parsed_args["detect_restart_file"] ? - auto_detect_restart_file(output_dir_style, base_output_dir) : - parsed_args["restart_file"] + final_restart_file = if detect_restart_file && isnothing(restart_file) + auto_detect_restart_file(output_dir_style_obj, base_output_dir) + else + restart_file + end output_dir = OutputPathGenerator.generate_output_path( base_output_dir; context = comms_ctx, - style = output_dir_style, + style = output_dir_style_obj, ) + + return output_dir, final_restart_file +end + +function get_sim_info(config::AtmosConfig) + (; comms_ctx, parsed_args) = config + FT = eltype(config) + + (; job_id) = config + + output_dir, restart_file = setup_output_dir( + job_id, + parsed_args["output_dir"], + parsed_args["output_dir_style"], + parsed_args["detect_restart_file"], + parsed_args["restart_file"], + comms_ctx, + ) + if parsed_args["log_to_file"] @info "Logging to $output_dir/output.log" logger = ClimaComms.FileLogger(comms_ctx, output_dir) @@ -775,10 +718,94 @@ function get_comms_context(parsed_args) return comms_ctx end +function get_mesh_warp_type(FT, parsed_args) + warp_type_str = parsed_args["mesh_warp_type"] + if warp_type_str == "SLEVE" + return SLEVEWarp{FT}( + eta = parsed_args["sleve_eta"], + s = parsed_args["sleve_s"], + ) + elseif warp_type_str == "Linear" + return LinearWarp() + else + error( + "Unknown mesh warp type string: $warp_type_str. Supported types are 'SLEVE' and 'Linear'", + ) + end +end + +function get_grid(parsed_args, params, context) + FT = eltype(params) + config = parsed_args["config"] + + # Common vertical discretization parameters + kwargs = ( + z_elem = parsed_args["z_elem"], + z_max = parsed_args["z_max"], + z_stretch = parsed_args["z_stretch"], + dz_bottom = parsed_args["dz_bottom"], + ) + + # Add topography parameters for non-column grids + if config != "column" + kwargs = ( + kwargs..., + topography = get_topography(FT, parsed_args), + topography_damping_factor = parsed_args["topography_damping_factor"], + mesh_warp_type = get_mesh_warp_type(FT, parsed_args), + topo_smoothing = parsed_args["topo_smoothing"], + ) + end + + # Grid-specific construction + if config == "sphere" + SphereGrid( + FT; + context, + radius = CAP.planet_radius(params), + h_elem = parsed_args["h_elem"], + nh_poly = parsed_args["nh_poly"], + bubble = parsed_args["bubble"], + deep_atmosphere = parsed_args["deep_atmosphere"], + kwargs..., + ) + elseif config == "column" + ColumnGrid(FT; context, kwargs...) + elseif config == "box" + BoxGrid( + FT; + context, + x_elem = parsed_args["x_elem"], + x_max = parsed_args["x_max"], + y_elem = parsed_args["y_elem"], + y_max = parsed_args["y_max"], + nh_poly = parsed_args["nh_poly"], + bubble = parsed_args["bubble"], + periodic_x = true, + periodic_y = true, + kwargs..., + ) + elseif config == "plane" + PlaneGrid( + FT; + context, + x_elem = parsed_args["x_elem"], + x_max = parsed_args["x_max"], + nh_poly = parsed_args["nh_poly"], + bubble = parsed_args["bubble"], + periodic_x = true, + kwargs..., + ) + end +end + function get_simulation(config::AtmosConfig) sim_info = get_sim_info(config) params = ClimaAtmosParameters(config) atmos = get_atmos(config, params) + comms_ctx = get_comms_context(config.parsed_args) + grid = get_grid(config.parsed_args, params, comms_ctx) + job_id = sim_info.job_id output_dir = sim_info.output_dir @info "Simulation info" job_id output_dir @@ -797,14 +824,16 @@ function get_simulation(config::AtmosConfig) sim_info.restart_file, hash(atmos), ) - spaces = get_spaces_restart(Y) + spaces = (; center_space = axes(Y.c), face_space = axes(Y.f)) # Fix the t_start in sim_info with the one from the restart sim_info = merge(sim_info, (; t_start)) end @info "Allocating Y: $s" else - spaces = get_spaces(config.parsed_args, params, config.comms_ctx) + spaces = get_spaces(grid) end + @info "Simulation Grid: $(spaces.center_space.grid)" + # TODO: add more information about the grid - stretch, etc. initial_condition = get_initial_condition(config.parsed_args, atmos) surface_setup = get_surface_setup(config.parsed_args) if !sim_info.restart @@ -891,7 +920,10 @@ function get_simulation(config::AtmosConfig) accum_str = join(CA.promote_period.(collect(periods_reductions)), ", ") checkpt_str = CA.promote_period(checkpoint_frequency) - @warn "The checkpointing frequency (dt_save_state_to_disk = $checkpt_str) should be an integer multiple of all diagnostics accumulation periods ($accum_str) so simulations can be safely restarted from any checkpoint" + @warn """The checkpointing frequency \ + (dt_save_state_to_disk = $checkpt_str) should be an integer \ + multiple of all diagnostics accumulation periods ($accum_str) \ + so simulations can be safely restarted from any checkpoint""" end end else diff --git a/src/solver/types.jl b/src/solver/types.jl index a16e278bff2..22518e664c4 100644 --- a/src/solver/types.jl +++ b/src/solver/types.jl @@ -1267,7 +1267,7 @@ function AtmosConfig( end """ - maybe_resolve_and_acquire_artifacts(input_str::AbstractString, context::ClimaComms.AbstractCommsContext) + maybe_resolve_and_acquire_artifacts(input_str::AbstractString, context) When given a string of the form `artifact"name"/something/else`, resolve the artifact path and download it (if not already available). @@ -1276,7 +1276,7 @@ In all the other cases, return the input unchanged. """ function maybe_resolve_and_acquire_artifacts( input_str::AbstractString, - context::ClimaComms.AbstractCommsContext, + context, ) matched = match(r"artifact\"([a-zA-Z0-9_]+)\"(\/.*)?", input_str) if isnothing(matched) @@ -1292,20 +1292,20 @@ end function maybe_resolve_and_acquire_artifacts( input, - _::ClimaComms.AbstractCommsContext, + _, ) return input end """ - config_with_resolved_and_acquired_artifacts(input_str::AbstractString, context::ClimaComms.AbstractCommsContext) + config_with_resolved_and_acquired_artifacts(input_str::AbstractString, context) Substitute strings of the form `artifact"name"/something/else` with the actual artifact path. """ function config_with_resolved_and_acquired_artifacts( config::AbstractDict, - context::ClimaComms.AbstractCommsContext, + context, ) return Dict( k => maybe_resolve_and_acquire_artifacts(v, context) for diff --git a/src/surface_conditions/surface_conditions.jl b/src/surface_conditions/surface_conditions.jl index b4afb7d6069..1a47876deea 100644 --- a/src/surface_conditions/surface_conditions.jl +++ b/src/surface_conditions/surface_conditions.jl @@ -381,11 +381,11 @@ end #For non-RCEMIPII box models with prescribed surface temp, assume that the latitude is 0. function surface_temperature( ::ZonallySymmetricSST, - coordinates::Union{Geometry.XZPoint, Geometry.XYZPoint}, + coordinates, surface_temp_params, ) - (; x) = coordinates - FT = eltype(x) + (; z) = coordinates + FT = eltype(z) return FT(300) end diff --git a/src/topography/topography.jl b/src/topography/topography.jl index 16a13fb78ca..0fc1fe9e363 100644 --- a/src/topography/topography.jl +++ b/src/topography/topography.jl @@ -1,4 +1,11 @@ using ClimaCore: Geometry, Spaces, Fields +export CosineTopography, + AgnesiTopography, + ScharTopography, + EarthTopography, + DCMIP200Topography, + Hughes2023Topography, + LinearWarp, SLEVEWarp ## ## Topography profiles for 2D and 3D boxes @@ -147,3 +154,35 @@ function topography_hughes2023(coord) ), ) end + +## +## Mesh warping types for topography +## + +abstract type MeshWarpType end + +""" + LinearWarp() + +Linear mesh warping that uniformly distributes vertical levels between the +surface and top of the domain. +""" +struct LinearWarp <: MeshWarpType end + +""" + SLEVEWarp(; eta = 0.7, s = 10.0) + +Smooth Level Vertical (SLEVE) coordinate warping for terrain-following meshes. + +# Arguments +- `eta`: Threshold parameter (if z/z_top > eta, no warping is applied). Default: 0.7 +- `s`: Decay scale parameter controlling how quickly the warping decays with height. Default: 10.0 + +# References +Schär et al. (2002), "A new terrain-following vertical coordinate formulation +for atmospheric prediction models", Mon. Wea. Rev. +""" +Base.@kwdef struct SLEVEWarp{FT <: AbstractFloat} <: MeshWarpType + eta::FT = 0.7 + s::FT = 10.0 +end diff --git a/src/utils/common_spaces.jl b/src/utils/common_spaces.jl deleted file mode 100644 index ccca9ef2647..00000000000 --- a/src/utils/common_spaces.jl +++ /dev/null @@ -1,175 +0,0 @@ -using ClimaCore: - Geometry, Domains, Meshes, Topologies, Spaces, Grids, Hypsography, Fields -using ClimaComms -using ClimaUtilities: SpaceVaryingInputs.SpaceVaryingInput - -function periodic_line_mesh(; x_max, x_elem) - domain = Domains.IntervalDomain( - Geometry.XPoint(zero(x_max)), - Geometry.XPoint(x_max); - periodic = true, - ) - return Meshes.IntervalMesh(domain; nelems = x_elem) -end - -function periodic_rectangle_mesh(; x_max, y_max, x_elem, y_elem) - x_domain = Domains.IntervalDomain( - Geometry.XPoint(zero(x_max)), - Geometry.XPoint(x_max); - periodic = true, - ) - y_domain = Domains.IntervalDomain( - Geometry.YPoint(zero(y_max)), - Geometry.YPoint(y_max); - periodic = true, - ) - domain = Domains.RectangleDomain(x_domain, y_domain) - return Meshes.RectilinearMesh(domain, x_elem, y_elem) -end - -# h_elem is the number of elements per side of every panel (6 panels in total) -function cubed_sphere_mesh(; radius, h_elem) - domain = Domains.SphereDomain(radius) - return Meshes.EquiangularCubedSphere(domain, h_elem) -end - -function make_horizontal_space( - mesh, - quad, - comms_ctx::ClimaComms.SingletonCommsContext, - bubble, -) - - space = if mesh isa Meshes.AbstractMesh1D - topology = Topologies.IntervalTopology(comms_ctx, mesh) - Spaces.SpectralElementSpace1D(topology, quad) - elseif mesh isa Meshes.AbstractMesh2D - topology = Topologies.Topology2D( - comms_ctx, - mesh, - Topologies.spacefillingcurve(mesh), - ) - Spaces.SpectralElementSpace2D(topology, quad; enable_bubble = bubble) - end - return space -end - -function make_horizontal_space(mesh, quad, comms_ctx, bubble) - if mesh isa Meshes.AbstractMesh1D - error("Distributed mode does not work with 1D horizontal spaces.") - elseif mesh isa Meshes.AbstractMesh2D - topology = Topologies.DistributedTopology2D( - comms_ctx, - mesh, - Topologies.spacefillingcurve(mesh), - ) - space = Spaces.SpectralElementSpace2D( - topology, - quad; - enable_bubble = bubble, - ) - end - return space -end - -function make_hybrid_spaces( - h_space, - z_max, - z_elem, - z_stretch; - deep = false, - parsed_args = nothing, -) - FT = eltype(z_max) - h_grid = Spaces.grid(h_space) - z_domain = Domains.IntervalDomain( - Geometry.ZPoint(zero(z_max)), - Geometry.ZPoint(z_max); - boundary_names = (:bottom, :top), - ) - z_mesh = Meshes.IntervalMesh(z_domain, z_stretch; nelems = z_elem) - @info "z heights" z_mesh.faces - device = ClimaComms.device(h_space) - z_topology = Topologies.IntervalTopology( - ClimaComms.SingletonCommsContext(device), - z_mesh, - ) - z_grid = Grids.FiniteDifferenceGrid(z_topology) - - topography = get_topography(FT, parsed_args) - if topography isa NoTopography - z_surface = zeros(h_space) - @info "No surface orography warp applied" - elseif topography isa EarthTopography - z_surface = SpaceVaryingInput( - AA.earth_orography_file_path(; - context = ClimaComms.context(h_space), - ), - "z", - h_space, - ) - @info "Remapping Earth orography from ETOPO2022 data onto horizontal space" - else - z_surface = SpaceVaryingInput(topography_function(topography), h_space) - @info "Using $(nameof(typeof(topography))) orography" - end - - if topography isa NoTopography - hypsography = Hypsography.Flat() - elseif topography isa EarthTopography - mask(x::FT) where {FT} = x * FT(x > 0) - z_surface = @. mask(z_surface) - # diff_cfl = νΔt/Δx² - diff_courant = 0.05 # Arbitrary example value. - Δh_scale = Spaces.node_horizontal_length_scale(h_space) - κ = FT(diff_courant * Δh_scale^2) - n_attenuation = parsed_args["topography_damping_factor"] - maxiter = Int(round(log(n_attenuation) / diff_courant)) - Hypsography.diffuse_surface_elevation!( - z_surface; - κ, - dt = FT(1), - maxiter, - ) - # Coefficient for horizontal diffusion may alternatively be - # determined from the empirical parameters suggested by - # E3SM v1/v2 Topography documentation found here: - # https://acme-climate.atlassian.net/wiki/spaces/DOC/pages/1456603764/V1+Topography+GLL+grids - z_surface = @. mask(z_surface) - if parsed_args["mesh_warp_type"] == "SLEVE" - @info "SLEVE mesh warp" - hypsography = Hypsography.SLEVEAdaption( - Geometry.ZPoint.(z_surface), - FT(parsed_args["sleve_eta"]), - FT(parsed_args["sleve_s"]), - ) - elseif parsed_args["mesh_warp_type"] == "Linear" - @info "Linear mesh warp" - hypsography = - Hypsography.LinearAdaption(Geometry.ZPoint.(z_surface)) - else - @error "Undefined mesh-warping option" - end - else - if parsed_args["topo_smoothing"] - Hypsography.diffuse_surface_elevation!(z_surface) - end - if parsed_args["mesh_warp_type"] == "SLEVE" - @info "SLEVE mesh warp" - hypsography = Hypsography.SLEVEAdaption( - Geometry.ZPoint.(z_surface), - FT(parsed_args["sleve_eta"]), - FT(parsed_args["sleve_s"]), - ) - elseif parsed_args["mesh_warp_type"] == "Linear" - @info "Linear mesh warp" - hypsography = - Hypsography.LinearAdaption(Geometry.ZPoint.(z_surface)) - end - end - - grid = Grids.ExtrudedFiniteDifferenceGrid(h_grid, z_grid, hypsography; deep) - center_space = Spaces.CenterExtrudedFiniteDifferenceSpace(grid) - face_space = Spaces.FaceExtrudedFiniteDifferenceSpace(grid) - return center_space, face_space -end diff --git a/src/utils/utilities.jl b/src/utils/utilities.jl index bad5af103a5..a2b18eafce2 100644 --- a/src/utils/utilities.jl +++ b/src/utils/utilities.jl @@ -5,6 +5,7 @@ import ClimaComms import ClimaCore: Spaces, Topologies, Fields, Geometry import LinearAlgebra: norm_sqr using Dates: DateTime, @dateformat_str +import StaticArrays: SVector, SMatrix is_energy_var(symbol) = symbol in (:ρe_tot, :ρae_tot) is_momentum_var(symbol) = symbol in (:uₕ, :ρuₕ, :u₃, :ρw) @@ -222,41 +223,28 @@ Extracts the `g³ʰ` sub-tensor from the `gⁱʲ` tensor. """ function g³ʰ(gⁱʲ) full_CT_axis = axes(gⁱʲ)[1] - CTh_axis = if full_CT_axis == Geometry.Contravariant123Axis() - Geometry.Contravariant12Axis() - elseif full_CT_axis == Geometry.Contravariant13Axis() - Geometry.Contravariant1Axis() - elseif full_CT_axis == Geometry.Contravariant23Axis() - Geometry.Contravariant2Axis() - else - error("$full_CT_axis is missing either vertical or horizontal sub-axes") - end N = length(full_CT_axis) - return Geometry.AxisTensor( - (Geometry.Contravariant3Axis(), CTh_axis), - view(Geometry.components(gⁱʲ), N:N, 1:(N - 1)), - ) -end - -""" - CTh_vector_type(space) - -Extracts the (abstract) horizontal contravariant vector type from the given -`AbstractSpace`. -""" -function CTh_vector_type(space) - full_CT_axis = axes(eltype(Fields.local_geometry_field(space).gⁱʲ))[1] - return if full_CT_axis == Geometry.Contravariant123Axis() - Geometry.Contravariant12Vector + gⁱʲ_components = Geometry.components(gⁱʲ) + FT = eltype(gⁱʲ_components) + g³ʰ_components = if full_CT_axis == Geometry.Contravariant123Axis() + @inbounds SMatrix{1, 2, FT, 2}( + gⁱʲ_components[N, 1], + gⁱʲ_components[N, 2], + ) elseif full_CT_axis == Geometry.Contravariant13Axis() - Geometry.Contravariant1Vector + @inbounds val = gⁱʲ_components[N, 1] + SMatrix{1, 2, FT, 2}(val, zero(FT)) elseif full_CT_axis == Geometry.Contravariant23Axis() - Geometry.Contravariant2Vector + @inbounds val = gⁱʲ_components[N, 1] + SMatrix{1, 2, FT, 2}(zero(FT), val) else error("$full_CT_axis is missing either vertical or horizontal sub-axes") end + axes_tuple = (Geometry.Contravariant3Axis(), Geometry.Contravariant12Axis()) + return Geometry.AxisTensor(axes_tuple, g³ʰ_components) end +has_topography(space::Spaces.FiniteDifferenceSpace) = false has_topography(space) = Spaces.grid(space).hypsography != Spaces.Grids.Flat() """ @@ -350,6 +338,9 @@ function do_dss(space::Spaces.AbstractSpace) Quadratures.GLL end +function do_dss(::Spaces.FiniteDifferenceSpace) + return false +end using ClimaComms is_distributed(::ClimaComms.SingletonCommsContext) = false @@ -581,18 +572,8 @@ function parse_date(date_str) ) end -function iscolumn(space) - # TODO: Our columns are 2+1D boxes with one element at the base. Fix this - isbox = - Meshes.domain(Spaces.topology(Spaces.horizontal_space(space))) isa - Domains.RectangleDomain - isbox || return false - has_one_element = - Meshes.nelements( - Spaces.topology(Spaces.horizontal_space(space)).mesh, - ) == 1 - has_one_element && return true -end +iscolumn(space::Spaces.FiniteDifferenceSpace) = true +iscolumn(space) = false function issphere(space) return Meshes.domain(Spaces.topology(Spaces.horizontal_space(space))) isa diff --git a/test/parameterized_tendencies/gravity_wave/orographic_gravity_wave/ogwd_3d.jl b/test/parameterized_tendencies/gravity_wave/orographic_gravity_wave/ogwd_3d.jl index 14504e653a2..783c2d2df02 100644 --- a/test/parameterized_tendencies/gravity_wave/orographic_gravity_wave/ogwd_3d.jl +++ b/test/parameterized_tendencies/gravity_wave/orographic_gravity_wave/ogwd_3d.jl @@ -16,14 +16,7 @@ include( ) include("../gw_plotutils.jl") -comms_ctx = ClimaComms.SingletonCommsContext() -(; config_file, job_id) = CA.commandline_kwargs() -config = CA.AtmosConfig(config_file; job_id, comms_ctx) - -config.parsed_args["topography"] = "Earth"; -config.parsed_args["topo_smoothing"] = false; -config.parsed_args["mesh_warp_type"] = "Linear"; -(; parsed_args) = config +context = ClimaComms.SingletonCommsContext() # load gfdl data include(joinpath(@__DIR__, "../../../artifact_funcs.jl")) @@ -86,12 +79,23 @@ z_elem = 33 dz_bottom = 300.0 radius = 6.371229e6 -quad = Quadratures.GLL{nh_poly + 1}() -horizontal_mesh = CA.cubed_sphere_mesh(; radius, h_elem) -h_space = CA.make_horizontal_space(horizontal_mesh, quad, comms_ctx, false) -z_stretch = Meshes.HyperbolicTangentStretching(dz_bottom) -center_space, face_space = - CA.make_hybrid_spaces(h_space, z_max, z_elem, z_stretch; parsed_args) +grid = CA.SphereGrid( + FT; + context, + z_elem, + z_max, + z_stretch = true, + dz_bottom, + radius, + h_elem, + nh_poly, + bubble = false, + topography = CA.EarthTopography(), + topography_damping_factor = 5, + mesh_warp_type = CA.LinearWarp(), + topo_smoothing = false, +) +(; center_space, face_space) = CA.get_spaces(grid) ᶜlocal_geometry = Fields.local_geometry_field(center_space) ᶠlocal_geometry = Fields.local_geometry_field(face_space) diff --git a/test/parameterized_tendencies/gravity_wave/orographic_gravity_wave/ogwd_baseflux.jl b/test/parameterized_tendencies/gravity_wave/orographic_gravity_wave/ogwd_baseflux.jl index 0196e7e619f..c4df1a6ebc3 100644 --- a/test/parameterized_tendencies/gravity_wave/orographic_gravity_wave/ogwd_baseflux.jl +++ b/test/parameterized_tendencies/gravity_wave/orographic_gravity_wave/ogwd_baseflux.jl @@ -12,10 +12,7 @@ include( joinpath(pkgdir(ClimaAtmos), "post_processing/remap", "remap_helpers.jl"), ) -comms_ctx = ClimaComms.SingletonCommsContext() -(; config_file, job_id) = CA.commandline_kwargs() -config = CA.AtmosConfig(config_file; job_id, comms_ctx) -config.parsed_args["topography"] = "NoWarp" +context = ClimaComms.SingletonCommsContext() # Create meshes and spaces h_elem = 6 @@ -24,18 +21,20 @@ z_max = 30e3 z_elem = 1 radius = 6.371229e6 -quad = Quadratures.GLL{nh_poly + 1}() -horizontal_mesh = CA.cubed_sphere_mesh(; radius, h_elem) -h_space = CA.make_horizontal_space(horizontal_mesh, quad, comms_ctx, false) -z_stretch = Meshes.Uniform() -center_space, face_space = CA.make_hybrid_spaces( - h_space, - z_max, +grid = CA.SphereGrid( + FT; + context, z_elem, - z_stretch; - parsed_args = config.parsed_args, + z_max, + z_stretch = false, + radius, + h_elem, + nh_poly, + bubble = false, + topography = CA.NoTopography(), ) - +(; center_space, face_space) = CA.get_spaces(grid) +h_space = Spaces.horizontal_space(center_space) ᶜlocal_geometry = Fields.local_geometry_field(center_space) ᶠlocal_geometry = Fields.local_geometry_field(face_space) diff --git a/test/test_helpers.jl b/test/test_helpers.jl index 10092d87b0c..1a453c14840 100644 --- a/test/test_helpers.jl +++ b/test/test_helpers.jl @@ -1,6 +1,7 @@ ### BoilerPlate Code using IntervalSets +import ClimaComms import ClimaCore: ClimaCore, Domains, @@ -28,6 +29,104 @@ function generate_test_simulation(config) return (; Y = Y, p = p, params = p.params, simulation = simulation) end +function periodic_line_mesh(; x_max, x_elem) + domain = Domains.IntervalDomain( + Geometry.XPoint(zero(x_max)), + Geometry.XPoint(x_max); + periodic = true, + ) + return Meshes.IntervalMesh(domain; nelems = x_elem) +end + +function periodic_rectangle_mesh(; x_max, y_max, x_elem, y_elem) + x_domain = Domains.IntervalDomain( + Geometry.XPoint(zero(x_max)), + Geometry.XPoint(x_max); + periodic = true, + ) + y_domain = Domains.IntervalDomain( + Geometry.YPoint(zero(y_max)), + Geometry.YPoint(y_max); + periodic = true, + ) + domain = Domains.RectangleDomain(x_domain, y_domain) + return Meshes.RectilinearMesh(domain, x_elem, y_elem) +end + +""" + make_horizontal_space(mesh, quad, comms_ctx::ClimaComms.SingletonCommsContext, bubble) + +Create a horizontal spectral element space from a mesh and quadrature. + +For 1D meshes, creates a `SpectralElementSpace1D`. +For 2D meshes, creates a `SpectralElementSpace2D` with optional bubble correction. + +# Arguments +- `mesh`: The horizontal mesh (1D or 2D) +- `quad`: The quadrature style +- `comms_ctx`: Communications context (must be `SingletonCommsContext` for 1D meshes) +- `bubble`: Enable bubble correction for 2D spaces + +# Returns +- A horizontal spectral element space +""" +function make_horizontal_space( + mesh, + quad, + comms_ctx::ClimaComms.SingletonCommsContext, + bubble, +) + space = if mesh isa Meshes.AbstractMesh1D + topology = Topologies.IntervalTopology(comms_ctx, mesh) + Spaces.SpectralElementSpace1D(topology, quad) + elseif mesh isa Meshes.AbstractMesh2D + topology = Topologies.Topology2D( + comms_ctx, + mesh, + Topologies.spacefillingcurve(mesh), + ) + Spaces.SpectralElementSpace2D(topology, quad; enable_bubble = bubble) + end + return space +end + +""" + make_horizontal_space(mesh, quad, comms_ctx, bubble) + +Create a horizontal spectral element space from a mesh and quadrature (distributed version). + +For distributed contexts, only 2D meshes are supported. + +# Arguments +- `mesh`: The horizontal mesh (must be 2D for distributed contexts) +- `quad`: The quadrature style +- `comms_ctx`: Communications context (distributed) +- `bubble`: Enable bubble correction + +# Returns +- A horizontal spectral element space + +# Throws +- `ErrorException` if a 1D mesh is provided (distributed mode doesn't support 1D spaces) +""" +function make_horizontal_space(mesh, quad, comms_ctx, bubble) + if mesh isa Meshes.AbstractMesh1D + error("Distributed mode does not work with 1D horizontal spaces.") + elseif mesh isa Meshes.AbstractMesh2D + topology = Topologies.DistributedTopology2D( + comms_ctx, + mesh, + Topologies.spacefillingcurve(mesh), + ) + space = Spaces.SpectralElementSpace2D( + topology, + quad; + enable_bubble = bubble, + ) + end + return space +end + function get_spherical_spaces(; FT = Float32) context = ClimaComms.SingletonCommsContext() radius = FT(10π) diff --git a/test/utilities.jl b/test/utilities.jl index 27b35cecf4a..4dcb029ae82 100644 --- a/test/utilities.jl +++ b/test/utilities.jl @@ -298,7 +298,7 @@ end @testset "interval domain" begin # Interval Spaces (; zlim, velem) = get_cartesian_spaces() - line_mesh = CA.periodic_line_mesh(; x_max = zlim[2], x_elem = velem) + line_mesh = periodic_line_mesh(; x_max = zlim[2], x_elem = velem) @test line_mesh isa Meshes.IntervalMesh @test Geometry.XPoint(zlim[1]) == Meshes.domain(line_mesh).coord_min @test Geometry.XPoint(zlim[2]) == Meshes.domain(line_mesh).coord_max @@ -308,7 +308,7 @@ end @testset "periodic rectangle meshes (spectral elements)" begin # Interval Spaces (; xlim, zlim, velem, helem, npoly) = get_cartesian_spaces() - rectangle_mesh = CA.periodic_rectangle_mesh(; + rectangle_mesh = periodic_rectangle_mesh(; x_max = xlim[2], y_max = xlim[2], x_elem = helem, @@ -329,14 +329,14 @@ end comms_ctx = ClimaComms.context(device) FT = eltype(xlim) # 1D Space - line_mesh = CA.periodic_line_mesh(; x_max = zlim[2], x_elem = velem) + line_mesh = periodic_line_mesh(; x_max = zlim[2], x_elem = velem) @test line_mesh isa Meshes.AbstractMesh1D horz_plane_space = - CA.make_horizontal_space(line_mesh, quad, comms_ctx, true) + make_horizontal_space(line_mesh, quad, comms_ctx, true) @test Spaces.column(horz_plane_space, 1, 1) isa Spaces.PointSpace # 2D Space - rectangle_mesh = CA.periodic_rectangle_mesh(; + rectangle_mesh = periodic_rectangle_mesh(; x_max = xlim[2], y_max = xlim[2], x_elem = helem, @@ -344,7 +344,7 @@ end ) @test rectangle_mesh isa Meshes.AbstractMesh2D horz_plane_space = - CA.make_horizontal_space(rectangle_mesh, quad, comms_ctx, true) + make_horizontal_space(rectangle_mesh, quad, comms_ctx, true) @test Spaces.nlevels(horz_plane_space) == 1 @test Spaces.node_horizontal_length_scale(horz_plane_space) == FT(π / npoly / 5) @@ -354,30 +354,26 @@ end @testset "make hybrid spaces" begin (; cent_space, face_space, xlim, zlim, velem, helem, npoly, quad) = get_cartesian_spaces() - config = CA.AtmosConfig( - Dict("topography" => "NoWarp", "topo_smoothing" => false), - ) device = ClimaComms.CPUSingleThreaded() - comms_ctx = ClimaComms.context(device) - z_stretch = Meshes.Uniform() - rectangle_mesh = CA.periodic_rectangle_mesh(; - x_max = xlim[2], - y_max = xlim[2], + context = ClimaComms.context(device) + grid = CA.BoxGrid( + Float32; + context, x_elem = helem, + x_max = xlim[2], y_elem = helem, + y_max = xlim[2], + z_elem = velem, + z_max = zlim[2], + nh_poly = npoly, + z_stretch = false, + bubble = true, + periodic_x = true, + periodic_y = true, ) - horz_plane_space = - CA.make_horizontal_space(rectangle_mesh, quad, comms_ctx, true) - test_cent_space, test_face_space = CA.make_hybrid_spaces( - horz_plane_space, - zlim[2], - velem, - z_stretch; - deep = false, - parsed_args = config.parsed_args, - ) - @test test_cent_space == cent_space - @test test_face_space == face_space + (; center_space, face_space) = CA.get_spaces(grid) + @test center_space == cent_space + @test face_space == face_space end @testset "promote_period" begin